Coverage for bilby/bilby_mcmc/flows.py: 88%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import torch 

2from glasflow.nflows.distributions.normal import StandardNormal 

3from glasflow.nflows.flows.base import Flow 

4from glasflow.nflows.nn import nets as nets 

5from glasflow.nflows.transforms import ( 

6 CompositeTransform, 

7 MaskedAffineAutoregressiveTransform, 

8 RandomPermutation, 

9) 

10from glasflow.nflows.transforms.coupling import ( 

11 AdditiveCouplingTransform, 

12 AffineCouplingTransform, 

13) 

14from glasflow.nflows.transforms.normalization import BatchNorm 

15from torch.nn import functional as F 

16 

17# Turn off parallelism 

18torch.set_num_threads(1) 

19torch.set_num_interop_threads(1) 

20 

21 

class NVPFlow(Flow):
    """Real NVP-style flow over flat feature vectors.

    Stacks ``num_layers`` coupling transforms that use a 1-dim checkerboard
    mask, with the mask alternating between consecutive layers. No
    multi-scale architecture is used. A ``BatchNorm`` can optionally follow
    every coupling layer, and a single ``RandomPermutation`` can optionally
    be appended at the end. The base distribution is a standard normal over
    ``features`` dimensions.

    Reference:
    > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.

    Adapted from the nflows example at:
    https://github.com/bayesiains/nflows/blob/master/nflows/flows/realnvp.py

    Parameters
    ----------
    features: int
        Dimensionality of the input vectors.
    hidden_features: int
        Hidden width of each coupling network.
    num_layers: int
        Number of coupling transforms to stack.
    num_blocks_per_layer: int
        Number of residual blocks in each coupling network.
    use_volume_preserving: bool
        Use additive couplings instead of affine couplings.
    activation: callable
        Activation used inside the residual networks.
    dropout_probability: float
        Dropout probability inside the residual networks.
    batch_norm_within_layers: bool
        Enable batch norm inside the residual networks.
    batch_norm_between_layers: bool
        Insert a BatchNorm transform after every coupling layer.
    random_permutation: bool
        Append a RandomPermutation after all coupling layers.
    """

    def __init__(
        self,
        features,
        hidden_features,
        num_layers,
        num_blocks_per_layer,
        use_volume_preserving=False,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        batch_norm_between_layers=False,
        random_permutation=True,
    ):
        coupling_cls = (
            AdditiveCouplingTransform
            if use_volume_preserving
            else AffineCouplingTransform
        )

        # Checkerboard mask: even positions start at -1, odd positions at +1.
        # The same tensor is handed to each coupling layer and then negated in
        # place, so consecutive layers receive alternating masks.
        checkerboard = -torch.ones(features)
        checkerboard[1::2] = 1

        def build_resnet(in_features, out_features):
            # Passed as transform_net_create_fn to every coupling layer.
            return nets.ResidualNet(
                in_features,
                out_features,
                hidden_features=hidden_features,
                num_blocks=num_blocks_per_layer,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=batch_norm_within_layers,
            )

        steps = []
        for _ in range(num_layers):
            steps.append(
                coupling_cls(
                    mask=checkerboard, transform_net_create_fn=build_resnet
                )
            )
            checkerboard.neg_()
            if batch_norm_between_layers:
                steps.append(BatchNorm(features=features))

        if random_permutation:
            steps.append(RandomPermutation(features=features))

        super().__init__(
            transform=CompositeTransform(steps),
            distribution=StandardNormal([features]),
        )

84 

85 

class BasicFlow(Flow):
    """Minimal flow: one masked affine autoregressive transform, then a
    random permutation, over a standard normal base distribution.

    Parameters
    ----------
    features: int
        Dimensionality of the input vectors. The autoregressive transform
        uses ``2 * features`` hidden units.
    """

    def __init__(self, features):
        autoregressive = MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=2 * features,
        )
        shuffle = RandomPermutation(features=features)
        super().__init__(
            transform=CompositeTransform([autoregressive, shuffle]),
            distribution=StandardNormal(shape=[features]),
        )