Coverage for parallel_bilby/generation.py: 89%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Module to generate/prepare data, likelihood, and priors for parallel runs.
4This will create a directory structure for your parallel runs to store the
5output files, logs and plots. It will also generate a `data_dump` that stores
6information on the run settings and data to be analysed.
7"""
import os
import pickle
import subprocess
from argparse import Namespace

import bilby
import bilby_pipe
import bilby_pipe.data_generation
import dynesty
import lalsimulation
import numpy as np

from . import __version__, slurm
from .parser import create_generation_parser, parse_generation_args
from .utils import get_cli_args
def get_version_info():
    """Collect the versions of the key dependencies for run metadata.

    :return: dict mapping ``<package>_version`` keys to version strings
    """
    versions = {
        "bilby_version": bilby.__version__,
        "bilby_pipe_version": bilby_pipe.__version__,
        "parallel_bilby_version": __version__,
        "dynesty_version": dynesty.__version__,
        "lalsimulation_version": lalsimulation.__version__,
    }
    return versions
def write_complete_config_file(parser, args, inputs):
    """Wrapper function that uses bilby_pipe's complete config writer.

    Note: currently this function does not verify that the written complete config is
    identical to the source config

    :param parser: The argparse.ArgumentParser to parse user input
    :param args: The parsed user input in a Namespace object
    :param inputs: The bilby_pipe.input.Input object storing user args
    :return: None
    """
    # Normalise fields that bilby_pipe's writer expects before delegating to it
    overrides = dict(
        request_cpus=1,
        sampler_kwargs="{}",
        mpi_timing_interval=0,
        log_directory=None,
    )
    for attribute, value in overrides.items():
        setattr(inputs, attribute, value)
    try:
        bilby_pipe.main.write_complete_config_file(parser, args, inputs)
    except AttributeError:
        # bilby_pipe expects the ini to have "online_pe" and some other non pBilby args
        pass
def create_generation_logger(outdir, label):
    """Set up the bilby logger for data generation and share it with bilby_pipe.

    :param outdir: Run output directory; logs go to ``<outdir>/log_data_generation``
    :param label: Run label used to name the log file
    :return: the configured ``bilby.core.utils.logger``
    """
    log_directory = os.path.join(outdir, "log_data_generation")
    bilby.core.utils.setup_logger(outdir=log_directory, label=label)
    generation_logger = bilby.core.utils.logger
    # Make bilby_pipe's data-generation module log through the same logger
    bilby_pipe.data_generation.logger = generation_logger
    return generation_logger
class ParallelBilbyDataGenerationInput(bilby_pipe.data_generation.DataGenerationInput):
    """Data-generation input for parallel-bilby runs.

    Extends bilby_pipe's ``DataGenerationInput`` to additionally pickle a
    ``data_dump`` file containing everything the analysis stage needs.
    """

    def __init__(self, args, unknown_args):
        """
        :param args: Parsed generation arguments (Namespace)
        :param unknown_args: Extra CLI arguments passed through to bilby_pipe
        """
        super().__init__(args, unknown_args)
        self.args = args
        self.sampler = "dynesty"
        self.sampling_seed = args.sampling_seed
        self.data_dump_file = f"{self.data_directory}/{self.label}_data_dump.pickle"
        self.setup_inputs()

    @property
    def sampling_seed(self):
        return self._sampling_seed

    @sampling_seed.setter
    def sampling_seed(self, sampling_seed):
        # Draw a random seed when none is supplied so the run is reproducible
        # once the seed is logged. Fixed: attribute was misspelled
        # "_samplng_seed", and the upper bound was the float 1e6 (passing a
        # float to np.random.randint is deprecated).
        if sampling_seed is None:
            sampling_seed = np.random.randint(1, 10**6)
        self._sampling_seed = sampling_seed
        np.random.seed(sampling_seed)

    def save_data_dump(self):
        """Pickle the run settings and data needed by the analysis stage."""
        with open(self.data_dump_file, "wb+") as file:
            data_dump = dict(
                waveform_generator=self.waveform_generator,
                ifo_list=self.interferometers,
                prior_file=self.prior_file,
                args=self.args,
                data_dump_file=self.data_dump_file,
                meta_data=self.meta_data,
                injection_parameters=self.injection_parameters,
            )
            pickle.dump(data_dump, file)

    def setup_inputs(self):
        """Prepare priors, likelihood and metadata, then write the data dump."""
        if self.likelihood_type == "ROQGravitationalWaveTransient":
            self.save_roq_weights()
        self.interferometers.plot_data(outdir=self.data_directory, label=self.label)

        # This is done before instantiating the likelihood so that it is the full prior
        self.priors.to_json(outdir=self.data_directory, label=self.label)
        self.prior_file = f"{self.data_directory}/{self.label}_prior.json"

        # We build the likelihood here to ensure the distance marginalization exist
        # before sampling
        self.likelihood

        self.meta_data.update(
            dict(
                config_file=self.ini,
                data_dump_file=self.data_dump_file,
                **get_version_info(),
            )
        )

        self.save_data_dump()
def generate_runner(parser=None, **kwargs):
    """
    API for running the generation from Python instead of the command line.
    It takes all the same options as the CLI, specified as keyword arguments,
    and combines them with the defaults in the parser.

    Parameters
    ----------
    parser: generation-parser
    **kwargs:
        Any keyword arguments that can be specified via the CLI

    Returns
    -------
    inputs: ParallelBilbyDataGenerationInput
    logger: bilby.core.utils.logger

    """
    # Create a dummy parser if necessary
    if parser is None:
        parser = create_generation_parser()

    # Start from the parser defaults, overlay caller-supplied options,
    # and package the result as a Namespace
    merged_args = dict(parse_generation_args(parser))
    merged_args.update(kwargs)
    args = Namespace(**merged_args)

    logger = create_generation_logger(outdir=args.outdir, label=args.label)
    for package, version in get_version_info().items():
        logger.info(f"{package} version: {version}")

    inputs = ParallelBilbyDataGenerationInput(args, [])
    logger.info(
        "Setting up likelihood with marginalizations: "
        f"distance={inputs.distance_marginalization}, "
        f"time={inputs.time_marginalization}, "
        f"phase={inputs.phase_marginalization}."
    )
    logger.info(f"Setting sampling-seed={inputs.sampling_seed}")
    logger.info(f"prior-file save at {inputs.prior_file}")
    logger.info(
        f"Initial meta_data ="
        f"{bilby_pipe.utils.pretty_print_dictionary(inputs.meta_data)}"
    )

    write_complete_config_file(parser=parser, args=args, inputs=inputs)
    logger.info(f"Complete ini written: {inputs.complete_ini_file}")

    return inputs, logger
def main():
    """
    parallel_bilby_generation entrypoint.

    This function is a wrapper around generate_runner(),
    giving it a command line interface.
    """

    # Parse command line arguments
    cli_args = get_cli_args()
    generation_parser = create_generation_parser()
    args = parse_generation_args(generation_parser, cli_args, as_namespace=True)

    # Initialise run
    inputs, logger = generate_runner(parser=generation_parser, **vars(args))

    # Write slurm script
    bash_file = slurm.setup_submit(inputs.data_dump_file, inputs, args, cli_args)
    if args.submit:
        # Use the list form with shell=False: the original built a shell
        # command string, which is fragile if the path contains spaces or
        # shell metacharacters.
        subprocess.run(["bash", bash_file])
    else:
        logger.info(f"Setup complete, now run:\n $ bash {bash_file}")