Coverage for parallel_bilby/generation.py: 89%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

84 statements  

1""" 

2Module to generate/prepare data, likelihood, and priors for parallel runs. 

3 

4This will create a directory structure for your parallel runs to store the 

5output files, logs and plots. It will also generate a `data_dump` that stores 

6information on the run settings and data to be analysed. 

7""" 

8import os 

9import pickle 

10import subprocess 

11from argparse import Namespace 

12 

13import bilby 

14import bilby_pipe 

15import bilby_pipe.data_generation 

16import dynesty 

17import lalsimulation 

18import numpy as np 

19 

20from . import __version__, slurm 

21from .parser import create_generation_parser, parse_generation_args 

22from .utils import get_cli_args 

23 

24 

def get_version_info():
    """Collect version strings of the packages used by a parallel_bilby run.

    Returns
    -------
    dict
        Keys, in this order: ``bilby_version``, ``bilby_pipe_version``,
        ``parallel_bilby_version``, ``dynesty_version`` and
        ``lalsimulation_version``. Each value is the corresponding package's
        ``__version__``.
    """
    versions = {}
    versions["bilby_version"] = bilby.__version__
    versions["bilby_pipe_version"] = bilby_pipe.__version__
    versions["parallel_bilby_version"] = __version__
    versions["dynesty_version"] = dynesty.__version__
    versions["lalsimulation_version"] = lalsimulation.__version__
    return versions

33 

34 

def write_complete_config_file(parser, args, inputs):
    """Wrapper function that uses bilby_pipe's complete config writer.

    Before delegating, this overwrites ``request_cpus``, ``sampler_kwargs``,
    ``mpi_timing_interval`` and ``log_directory`` on ``inputs`` with fixed
    values, so ``inputs`` is mutated in place.

    Writing is best-effort: if bilby_pipe's writer raises ``AttributeError``,
    a warning is logged and the function returns without writing the file.

    Note: currently this function does not verify that the written complete
    config is identical to the source config.

    :param parser: The argparse.ArgumentParser to parse user input
    :param args: The parsed user input in a Namespace object
    :param inputs: The bilby_pipe.input.Input object storing user args
    :return: None
    """
    # ``bilby_pipe.main`` is a submodule: import it explicitly so the attribute
    # lookup below does not depend on some other module having imported it
    # first. Previously that failure mode surfaced as an AttributeError that
    # was silently swallowed by the handler below.
    import bilby_pipe.main

    inputs.request_cpus = 1
    inputs.sampler_kwargs = "{}"
    inputs.mpi_timing_interval = 0
    inputs.log_directory = None
    try:
        bilby_pipe.main.write_complete_config_file(parser, args, inputs)
    except AttributeError as err:
        # bilby_pipe expects the ini to have "online_pe" and some other non
        # pBilby args. Keep this best-effort, but make the skip visible.
        bilby.core.utils.logger.warning(
            f"Unable to write complete config file, skipping: {err}"
        )

55 

56 

def create_generation_logger(outdir, label):
    """Configure bilby's logger for the data-generation step and return it.

    Log output is directed to ``<outdir>/log_data_generation`` using ``label``.
    The same logger object is also assigned to
    ``bilby_pipe.data_generation.logger``, so bilby_pipe's data-generation
    module logs through it too.

    :param outdir: Top-level output directory of the run
    :param label: Run label passed to bilby's ``setup_logger``
    :return: ``bilby.core.utils.logger``
    """
    generation_logger = bilby.core.utils.logger
    log_dir = os.path.join(outdir, "log_data_generation")
    bilby.core.utils.setup_logger(outdir=log_dir, label=label)
    bilby_pipe.data_generation.logger = generation_logger
    return generation_logger

64 

65 

class ParallelBilbyDataGenerationInput(bilby_pipe.data_generation.DataGenerationInput):
    """bilby_pipe data-generation input specialised for parallel_bilby.

    Construction runs the parent class initialiser, fixes the sampler to
    ``"dynesty"``, seeds NumPy's global RNG and then writes the run's outputs
    into ``data_directory``: data plots, prior JSON and the pickled data dump.
    See :meth:`setup_inputs` for the exact steps.
    """

    def __init__(self, args, unknown_args):
        super().__init__(args, unknown_args)
        self.args = args
        self.sampler = "dynesty"
        # The setter also seeds numpy's global RNG with the chosen value.
        self.sampling_seed = args.sampling_seed
        self.data_dump_file = os.path.join(
            self.data_directory, f"{self.label}_data_dump.pickle"
        )
        self.setup_inputs()

    @property
    def sampling_seed(self):
        """Seed used for sampling; assigning it also calls ``np.random.seed``."""
        return self._sampling_seed

    @sampling_seed.setter
    def sampling_seed(self, sampling_seed):
        # When no seed was supplied, draw one so a concrete value is always
        # stored and applied.
        if sampling_seed is None:
            sampling_seed = np.random.randint(1, int(1e6))
        self._sampling_seed = sampling_seed
        np.random.seed(sampling_seed)

    def save_data_dump(self):
        """Pickle the run's data products to ``self.data_dump_file``.

        The pickled dict holds the waveform generator, interferometers, prior
        file path, parsed args, data-dump path, meta data and injection
        parameters.
        """
        # Gather the payload *before* opening the file: opening with "wb"
        # truncates it, so an exception while reading any of these
        # attributes would otherwise leave an empty data dump behind.
        data_dump = dict(
            waveform_generator=self.waveform_generator,
            ifo_list=self.interferometers,
            prior_file=self.prior_file,
            args=self.args,
            data_dump_file=self.data_dump_file,
            meta_data=self.meta_data,
            injection_parameters=self.injection_parameters,
        )
        with open(self.data_dump_file, "wb") as file:
            pickle.dump(data_dump, file)

    def setup_inputs(self):
        """Write the run's plots, prior file and data dump to disk.

        Steps, in order:
        1. Save ROQ weights (ROQ likelihood only).
        2. Plot the interferometer data.
        3. Write the full prior to JSON.
        4. Build the likelihood.
        5. Add config/version info to ``meta_data``.
        6. Pickle everything via :meth:`save_data_dump`.
        """
        if self.likelihood_type == "ROQGravitationalWaveTransient":
            self.save_roq_weights()
        self.interferometers.plot_data(outdir=self.data_directory, label=self.label)

        # This is done before instantiating the likelihood so that it is the full prior
        self.priors.to_json(outdir=self.data_directory, label=self.label)
        self.prior_file = os.path.join(self.data_directory, f"{self.label}_prior.json")

        # Accessing the attribute builds the likelihood here, to ensure the
        # distance marginalization is set up before sampling.
        self.likelihood

        self.meta_data.update(
            dict(
                config_file=self.ini,
                data_dump_file=self.data_dump_file,
                **get_version_info(),
            )
        )

        self.save_data_dump()

121 

122 

def generate_runner(parser=None, **kwargs):
    """
    Run the data generation from Python instead of the command line.

    Any option accepted by the CLI may be passed as a keyword argument.
    Keyword arguments take precedence over the parser's defaults.

    Parameters
    ----------
    parser: generation-parser, optional
        Parser used for the defaults and for writing the complete ini. When
        omitted, one is built with ``create_generation_parser()``.
    **kwargs:
        Any keyword arguments that can be specified via the CLI

    Returns
    -------
    inputs: ParallelBilbyDataGenerationInput
    logger: bilby.core.utils.logger
    """
    generation_parser = create_generation_parser() if parser is None else parser

    # Start from the parser's defaults and let caller-supplied options win.
    merged_args = dict(parse_generation_args(generation_parser))
    merged_args.update(kwargs)
    args = Namespace(**merged_args)

    logger = create_generation_logger(outdir=args.outdir, label=args.label)
    for package_name, package_version in get_version_info().items():
        logger.info(f"{package_name} version: {package_version}")

    inputs = ParallelBilbyDataGenerationInput(args, [])

    marginalization_summary = (
        f"distance={inputs.distance_marginalization}, "
        f"time={inputs.time_marginalization}, "
        f"phase={inputs.phase_marginalization}."
    )
    logger.info("Setting up likelihood with marginalizations: " + marginalization_summary)
    logger.info(f"Setting sampling-seed={inputs.sampling_seed}")
    logger.info(f"prior-file save at {inputs.prior_file}")
    meta_data_text = bilby_pipe.utils.pretty_print_dictionary(inputs.meta_data)
    logger.info(f"Initial meta_data ={meta_data_text}")

    write_complete_config_file(parser=generation_parser, args=args, inputs=inputs)
    logger.info(f"Complete ini written: {inputs.complete_ini_file}")

    return inputs, logger

174 

175 

def main():
    """
    parallel_bilby_generation entrypoint.

    This function is a wrapper around generate_runner(), giving it a command
    line interface. It does the following:
    - parses the CLI arguments;
    - runs the data generation;
    - writes the submission bash script via ``slurm.setup_submit``;
    - runs that script when ``args.submit`` is set, otherwise logs the
      command to run it.
    """

    # Parse command line arguments
    cli_args = get_cli_args()
    generation_parser = create_generation_parser()
    args = parse_generation_args(generation_parser, cli_args, as_namespace=True)

    # Initialise run
    inputs, logger = generate_runner(parser=generation_parser, **vars(args))

    # Write slurm script
    bash_file = slurm.setup_submit(inputs.data_dump_file, inputs, args, cli_args)
    if args.submit:
        # Pass the script path as an argument rather than via a shell string,
        # so paths containing spaces or shell metacharacters are handled
        # correctly.
        result = subprocess.run(["bash", bash_file])
        if result.returncode != 0:
            # Previously a failed submission was silently ignored.
            logger.error(
                f"Submission script {bash_file} exited with code {result.returncode}"
            )
    else:
        logger.info(f"Setup complete, now run:\n $ bash {bash_file}")