Coverage for bilby/bilby_mcmc/flows.py: 88%
34 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import torch
2from glasflow.nflows.distributions.normal import StandardNormal
3from glasflow.nflows.flows.base import Flow
4from glasflow.nflows.nn import nets as nets
5from glasflow.nflows.transforms import (
6 CompositeTransform,
7 MaskedAffineAutoregressiveTransform,
8 RandomPermutation,
9)
10from glasflow.nflows.transforms.coupling import (
11 AdditiveCouplingTransform,
12 AffineCouplingTransform,
13)
14from glasflow.nflows.transforms.normalization import BatchNorm
15from torch.nn import functional as F
# Turn off parallelism: restrict torch to a single intra-op and a single
# inter-op thread. NOTE(review): presumably this avoids torch's thread pools
# competing with the sampler's own process-level parallelism — confirm.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
class NVPFlow(Flow):
    """A simplified version of Real NVP for 1-dim inputs.

    This implementation uses 1-dim checkerboard masking but doesn't use
    multi-scaling.

    Reference:
    > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.

    This class has been modified from the example found at:
    https://github.com/bayesiains/nflows/blob/master/nflows/flows/realnvp.py
    """

    def __init__(
        self,
        features,
        hidden_features,
        num_layers,
        num_blocks_per_layer,
        use_volume_preserving=False,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        batch_norm_between_layers=False,
        random_permutation=True,
    ):
        # Additive couplings are volume preserving; affine couplings are not.
        coupling_cls = (
            AdditiveCouplingTransform
            if use_volume_preserving
            else AffineCouplingTransform
        )

        # 1-dim checkerboard mask: even-indexed features get -1, odd get +1.
        mask = torch.ones(features)
        mask[::2] = -1

        def build_conditioner(in_features, out_features):
            # Factory for the conditioner network used by each coupling layer.
            return nets.ResidualNet(
                in_features,
                out_features,
                hidden_features=hidden_features,
                num_blocks=num_blocks_per_layer,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=batch_norm_within_layers,
            )

        transforms = []
        for _ in range(num_layers):
            transforms.append(
                coupling_cls(mask=mask, transform_net_create_fn=build_conditioner)
            )
            # Flip the checkerboard so the next layer transforms the other half.
            mask *= -1
            if batch_norm_between_layers:
                transforms.append(BatchNorm(features=features))

        if random_permutation:
            transforms.append(RandomPermutation(features=features))

        super().__init__(
            transform=CompositeTransform(transforms),
            distribution=StandardNormal([features]),
        )
class BasicFlow(Flow):
    """A minimal flow: one masked affine autoregressive transform followed
    by a random permutation, over a standard-normal base distribution.

    Parameters
    ----------
    features: int
        Dimensionality of the input; the hidden layer uses ``2 * features``.
    """

    def __init__(self, features):
        super().__init__(
            transform=CompositeTransform(
                [
                    MaskedAffineAutoregressiveTransform(
                        features=features, hidden_features=2 * features
                    ),
                    RandomPermutation(features=features),
                ]
            ),
            distribution=StandardNormal(shape=[features]),
        )