Coverage for parallel_bilby/analysis/likelihood.py: 64%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

56 statements  

1import inspect 

2from importlib import import_module 

3 

4import bilby 

5import bilby_pipe 

6import numpy as np 

7from bilby.core.utils import logger 

8 

9 

def reorder_loglikelihoods(unsorted_loglikelihoods, unsorted_samples, sorted_samples):
    """Reorder the stored log-likelihoods after the samples have been reweighted.

    A sorting index is built by matching each row of the reweighted
    ``result.samples`` (``sorted_samples``) exactly against the raw samples
    (``unsorted_samples``). That index is then used to reorder the
    log-likelihoods.

    Parameters
    ----------
    unsorted_loglikelihoods: array-like
        The log-likelihoods corresponding to ``unsorted_samples``. Must
        support indexing with a list of integers (e.g. a numpy array).
    unsorted_samples, sorted_samples: array-like
        Unsorted and sorted values of the samples, one sample per row. These
        should be of the same shape and contain the same sample values, but
        in different orders.

    Returns
    -------
    sorted_loglikelihoods: array-like
        The log-likelihoods reordered to match that of the ``sorted_samples``.

    Raises
    ------
    ValueError
        If a row of ``sorted_samples`` has no exact match in
        ``unsorted_samples``.
    """
    unsorted_samples = np.asarray(unsorted_samples)
    sorted_samples = np.asarray(sorted_samples)

    # Index every raw sample row once (O(n)) instead of comparing each sorted
    # row against all raw rows (O(n^2)). Matching is exact, as with ``==``.
    row_to_indices = {}
    for jj, row in enumerate(unsorted_samples.tolist()):
        row_to_indices.setdefault(tuple(row), []).append(jj)

    idxs = []
    n_ambiguous = 0
    for ii in range(len(unsorted_loglikelihoods)):
        matches = row_to_indices.get(tuple(sorted_samples[ii].tolist()))
        if not matches:
            raise ValueError(
                f"Sorted sample {ii} has no exact match among the unsorted samples"
            )
        if len(matches) > 1:
            n_ambiguous += 1
        # Take the earliest matching raw sample when there are duplicates.
        idxs.append(matches[0])

    if n_ambiguous:
        logger.warning(
            "Multiple likelihood matches found between sorted and "
            f"unsorted samples for {n_ambiguous} sample(s). Taking the first match."
        )
    return unsorted_loglikelihoods[idxs]

44 

45 

def roq_likelihood_kwargs(args):
    """Return the kwargs required for the ROQ setup.

    The ROQ basis is taken from the first available source, in this order:

    1. ``args.likelihood_roq_params`` together with ``args.likelihood_roq_weights``
    2. ``args.roq_folder`` (``params.dat`` is read from it) with ``args.weight_file``
    3. ``args.roq_linear_matrix`` with ``args.roq_quadratic_matrix``

    An attribute that is missing or set to ``None`` does not count as an
    available source. When no source is available, the ROQ entries are left
    as ``None``.

    Parameters
    ----------
    args: Namespace
        The parser arguments

    Returns
    -------
    kwargs: dict
        A dictionary of the required kwargs: ``weights``, ``roq_params``,
        ``linear_matrix``, ``quadratic_matrix`` and ``roq_scale_factor``.
    """
    kwargs = dict(
        weights=None,
        roq_params=None,
        linear_matrix=None,
        quadratic_matrix=None,
        roq_scale_factor=args.roq_scale_factor,
    )

    # Require the values to be set, not merely present. This matches the
    # other sources, so that ``None`` placeholders do not shadow a valid
    # roq_folder or matrix configuration.
    roq_params = getattr(args, "likelihood_roq_params", None)
    roq_weights = getattr(args, "likelihood_roq_weights", None)

    if roq_params is not None and roq_weights is not None:
        kwargs["roq_params"] = roq_params
        kwargs["weights"] = roq_weights
    elif getattr(args, "roq_folder", None) is not None:
        logger.info(f"Loading ROQ weights from {args.roq_folder}, {args.weight_file}")
        # params.dat carries a header row of column names.
        kwargs["roq_params"] = np.genfromtxt(
            args.roq_folder + "/params.dat", names=True
        )
        kwargs["weights"] = args.weight_file
    elif getattr(args, "roq_linear_matrix", None) is not None:
        logger.info(f"Loading linear_matrix from {args.roq_linear_matrix}")
        logger.info(f"Loading quadratic_matrix from {args.roq_quadratic_matrix}")
        kwargs["linear_matrix"] = args.roq_linear_matrix
        kwargs["quadratic_matrix"] = args.roq_quadratic_matrix
    return kwargs

85 

86 

def setup_likelihood(interferometers, waveform_generator, priors, args):
    """Set up and return the likelihood selected by ``args.likelihood_type``.

    ``args.likelihood_type`` may be ``"GravitationalWaveTransient"``,
    ``"ROQGravitationalWaveTransient"``, or a dotted import path to a custom
    likelihood class (e.g. ``"my_package.module.MyLikelihood"``).

    Parameters
    ----------
    interferometers: bilby.gw.detectors.InterferometerList
        The pre-loaded bilby IFO
    waveform_generator: bilby.gw.waveform_generator.LALCBCWaveformGenerator
        The waveform generation
    priors: dict
        The priors, used for setting up marginalization
    args: Namespace
        The parser arguments

    Returns
    -------
    likelihood: bilby.gw.likelihood.GravitationalWaveTransient
        The likelihood: a GravitationalWaveTransient, an
        ROQGravitationalWaveTransient, or an instance of the custom class.

    Raises
    ------
    ValueError
        If ``args.likelihood_type`` is neither a known likelihood name nor a
        dotted import path.
    """
    likelihood_kwargs = dict(
        interferometers=interferometers,
        waveform_generator=waveform_generator,
        priors=priors,
        phase_marginalization=args.phase_marginalization,
        distance_marginalization=args.distance_marginalization,
        distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
        time_marginalization=args.time_marginalization,
        reference_frame=args.reference_frame,
        time_reference=args.time_reference,
    )

    if args.likelihood_type == "GravitationalWaveTransient":
        Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
        likelihood_kwargs.update(jitter_time=args.jitter_time)

    elif args.likelihood_type == "ROQGravitationalWaveTransient":
        Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient

        if args.time_marginalization:
            logger.warning(
                "Time marginalization not implemented for "
                "ROQGravitationalWaveTransient: option ignored"
            )

        likelihood_kwargs.pop("time_marginalization", None)
        likelihood_kwargs.pop("jitter_time", None)
        likelihood_kwargs.update(roq_likelihood_kwargs(args))

    elif "." in args.likelihood_type:
        # Custom likelihood given as "package.module.ClassName".
        module, _, likelihood_class = args.likelihood_type.rpartition(".")
        Likelihood = getattr(import_module(module), likelihood_class)

        # Accept the extra kwargs as an unparsed string, an already-parsed
        # dict, or None (no extra kwargs).
        extra_kwargs = getattr(args, "extra_likelihood_kwargs", None)
        if isinstance(extra_kwargs, str):
            extra_kwargs = bilby_pipe.utils.convert_string_to_dict(extra_kwargs)
        if extra_kwargs:
            likelihood_kwargs.update(extra_kwargs)

        if "roq" in args.likelihood_type.lower():
            likelihood_kwargs.pop("time_marginalization", None)
            likelihood_kwargs.pop("jitter_time", None)
            # A plain parser Namespace has no precomputed ROQ kwargs, so build
            # them here unless the caller already attached them to args.
            roq_kwargs = getattr(args, "roq_likelihood_kwargs", None)
            if roq_kwargs is None:
                roq_kwargs = roq_likelihood_kwargs(args)
            likelihood_kwargs.update(roq_kwargs)
    else:
        raise ValueError(f"Unknown Likelihood class {args.likelihood_type}")

    # Pass only the kwargs that the chosen class's __init__ explicitly names;
    # options it does not accept (e.g. unsupported marginalizations) are dropped.
    init_spec = inspect.getfullargspec(Likelihood.__init__)
    accepted = set(init_spec.args) | set(init_spec.kwonlyargs)
    likelihood_kwargs = {
        key: value for key, value in likelihood_kwargs.items() if key in accepted
    }

    logger.info(
        f"Initialise likelihood {Likelihood} with kwargs: \n{likelihood_kwargs}"
    )

    likelihood = Likelihood(**likelihood_kwargs)
    return likelihood