Coverage for bilby/core/sampler/cpnest.py: 26%

144 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import array 

2import copy 

3import sys 

4 

5import numpy as np 

6from numpy.lib.recfunctions import structured_to_unstructured 

7from pandas import DataFrame 

8 

9from ..utils import check_directory_exists_and_if_not_mkdir, logger 

10from .base_sampler import NestedSampler, signal_wrapper 

11from .proposal import JumpProposalCycle, Sample 

12 

13 

class Cpnest(NestedSampler):
    """bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `cpnest.CPNest`, see documentation
    for that class for further help. Under Parameters, we list commonly
    used kwargs and the bilby defaults.

    Parameters
    ==========
    nlive: int (500)
        The number of live points, note this can also equivalently be given as
        one of [npoints, nlives, n_live_points]
    seed: int (None)
        Initialised random seed. When no seed is supplied a warning is logged
        stating that cpnest will use 1234.
    nthreads: int, (1)
        Number of threads to use
    maxmcmc: int (1000)
        The maximum number of MCMC steps to take
    verbose: int (3)
        Verbosity setting forwarded to cpnest; governs how much information
        about the convergence is printed during sampling
    resume: Bool (True)
        Whether or not to resume from a previous run
    output: str
        Where to write the CPNest, by default this is
        {self.outdir}/cpnest_{self.label}/

    """

    sampler_name = "cpnest"
    default_kwargs = dict(
        verbose=3,
        nthreads=1,
        nlive=500,
        maxmcmc=1000,
        seed=None,
        poolsize=100,
        nhamiltonian=0,
        resume=True,
        output=None,
        proposals=None,
        n_periodic_checkpoint=8000,
    )
    hard_exit = True
    sampling_seed_key = "seed"

    def _translate_kwargs(self, kwargs):
        """Map bilby's generic kwarg aliases onto the names cpnest expects.

        Any alias listed in ``self.npoints_equiv_kwargs`` is moved to
        ``"nlive"`` and any alias in ``self.npool_equiv_kwargs`` is moved to
        ``"nthreads"``, unless the cpnest name was given explicitly.

        Parameters
        ==========
        kwargs: dict
            The user supplied keyword arguments; modified in place.

        Returns
        =======
        dict
            The translated kwargs, so that overrides can chain this call the
            same way this method chains the parent implementation.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "nlive" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nlive"] = kwargs.pop(equiv)
        if "nthreads" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nthreads"] = kwargs.pop(equiv)

        if "seed" not in kwargs:
            logger.warning("No seed provided, cpnest will use 1234.")
        return kwargs

    @signal_wrapper
    def run_sampler(self):
        """Run cpnest and store its output on ``self.result``.

        Builds a cpnest ``Model`` around this sampler's likelihood and prior,
        constructs ``CPNest`` (dropping newer kwargs one at a time if the
        installed cpnest rejects them), runs it, and copies the posterior
        samples, nested samples, weights, evidence and information gain onto
        ``self.result``.

        Returns
        =======
        The populated ``self.result`` object.

        Raises
        ======
        TypeError
            If ``CPNest`` cannot be constructed even after removing the
            optional kwargs.
        """
        from cpnest import CPNest
        from cpnest import model as cpmodel
        from cpnest.nest2pos import compute_weights
        from cpnest.parameter import LivePoint

        class Model(cpmodel.Model):
            """A wrapper class to pass our log_likelihood into cpnest"""

            def __init__(self, names, priors):
                self.names = names
                self.priors = priors
                self._update_bounds()

            # NOTE: the two static methods below deliberately have no `self`
            # parameter of their own; the `self` they use is the enclosing
            # Cpnest instance captured from run_sampler's scope.
            @staticmethod
            def log_likelihood(x, **kwargs):
                theta = [x[n] for n in self.search_parameter_keys]
                return self.log_likelihood(theta)

            @staticmethod
            def log_prior(x, **kwargs):
                theta = [x[n] for n in self.search_parameter_keys]
                return self.log_prior(theta)

            def _update_bounds(self):
                """Refresh ``self.bounds`` from each prior's minimum/maximum."""
                self.bounds = [
                    [self.priors[key].minimum, self.priors[key].maximum]
                    for key in self.names
                ]

            def new_point(self):
                """Draw a point from the prior"""
                prior_samples = self.priors.sample()
                # Bounds are re-read after sampling; presumably sampling can
                # change prior limits (e.g. conditional priors) - TODO confirm.
                self._update_bounds()
                point = LivePoint(
                    self.names,
                    array.array("d", [prior_samples[name] for name in self.names]),
                )
                return point

        self._resolve_proposal_functions()
        model = Model(self.search_parameter_keys, self.priors)
        out = None
        # kwargs that a given cpnest installation may not accept; they are
        # removed one by one, in this order, until construction succeeds.
        remove_kwargs = ["proposals", "n_periodic_checkpoint"]
        while out is None:
            try:
                out = CPNest(model, **self.kwargs)
            except TypeError as e:
                if not remove_kwargs:
                    # Keep the underlying cpnest error attached as the cause.
                    raise TypeError("Unable to initialise cpnest sampler") from e
                kwarg = remove_kwargs.pop(0)
                logger.info(f"CPNest init. failed with error {e}, please update")
                logger.info(f"Attempting to rerun with kwarg {kwarg} removed")
                # Tolerate the key already being absent so that the retry
                # loop is not derailed by a KeyError.
                self.kwargs.pop(kwarg, None)
        try:
            out.run()
        except SystemExit:
            out.checkpoint()
            self.write_current_state_and_exit()

        if self.plot:
            out.plot()

        self.calc_likelihood_count()
        self.result.samples = structured_to_unstructured(
            out.posterior_samples[self.search_parameter_keys]
        )
        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=""))
        self.result.nested_samples.rename(
            columns=dict(logL="log_likelihood"), inplace=True
        )
        _, log_weights = compute_weights(
            np.array(self.result.nested_samples.log_likelihood),
            np.array(out.NS.state.nlive),
        )
        self.result.nested_samples["weights"] = np.exp(log_weights)
        self.result.log_evidence = out.NS.state.logZ
        self.result.log_evidence_err = np.sqrt(out.NS.state.info / out.NS.state.nlive)
        self.result.information_gain = out.NS.state.info
        return self.result

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Overwrites the base class to make sure that :code:`CPNest` terminates
        properly as :code:`CPNest` handles all the multiprocessing internally.

        Logs the interruption and then exits the interpreter with
        ``self.exit_code``.
        """
        self._log_interruption(signum=signum)
        sys.exit(self.exit_code)

    def _verify_kwargs_against_default_kwargs(self):
        """
        Set the directory where the output will be written
        and check resume and checkpoint status.

        The output path defaults to ``{outdir}/cpnest_{label}/``, is forced to
        end with a trailing slash, and is created if missing. Periodic
        checkpointing is disabled when ``resume`` is falsy.
        """
        if not self.kwargs["output"]:
            self.kwargs["output"] = f"{self.outdir}/cpnest_{self.label}/"
        if not self.kwargs["output"].endswith("/"):
            self.kwargs["output"] = f"{self.kwargs['output']}/"
        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
        if self.kwargs["n_periodic_checkpoint"] and not self.kwargs["resume"]:
            self.kwargs["n_periodic_checkpoint"] = None
        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def _resolve_proposal_functions(self):
        """Convert bilby jump proposals in ``self.kwargs["proposals"]`` for cpnest.

        A single ``JumpProposalCycle`` is expanded to a dict using it for both
        the ``mhs`` and ``hmc`` keys. Every ``JumpProposalCycle`` value is then
        replaced by the result of ``cpnest_proposal_cycle_factory``; values
        that are already cpnest ``ProposalCycle`` instances are left alone.

        Raises
        ======
        TypeError
            If a proposal is neither a ``JumpProposalCycle`` nor a cpnest
            ``ProposalCycle``.
        """
        from cpnest.proposal import ProposalCycle

        if "proposals" in self.kwargs:
            if self.kwargs["proposals"] is None:
                return
            if isinstance(self.kwargs["proposals"], JumpProposalCycle):
                self.kwargs["proposals"] = dict(
                    mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"]
                )
            for key, proposal in self.kwargs["proposals"].items():
                if isinstance(proposal, JumpProposalCycle):
                    self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory(
                        proposal
                    )
                elif isinstance(proposal, ProposalCycle):
                    pass
                else:
                    raise TypeError("Unknown proposal type")

199 

200 

def cpnest_proposal_factory(jump_proposal):
    """Wrap a bilby jump proposal in a cpnest ``EnsembleProposal`` instance.

    Parameters
    ==========
    jump_proposal:
        The callable bilby proposal; it is invoked with ``sample=`` and
        ``sampler_name="cpnest"`` and must expose a ``log_j`` attribute.

    Returns
    =======
    An instance of a locally defined ``EnsembleProposal`` subclass that
    delegates to ``jump_proposal``.
    """
    from cpnest.proposal import EnsembleProposal

    class CPNestEnsembleProposal(EnsembleProposal):
        def __init__(self, jp):
            self.jump_proposal = jp
            self.ensemble = None

        def __call__(self, sample, **kwargs):
            # Calling the proposal is the same as asking it for a sample.
            return self.get_sample(sample, **kwargs)

        def get_sample(self, cpnest_sample, **kwargs):
            """Propose a new point and write it back into ``cpnest_sample``."""
            current = Sample.from_cpnest_live_point(cpnest_sample)
            # Only replace the stored ensemble when new coordinates are given.
            if "coordinates" in kwargs:
                self.ensemble = kwargs["coordinates"]
            proposed = self.jump_proposal(
                sample=current, sampler_name="cpnest", **kwargs
            )
            self.log_J = self.jump_proposal.log_j
            return self._update_cpnest_sample(cpnest_sample, proposed)

        @staticmethod
        def _update_cpnest_sample(cpnest_sample, sample):
            """Copy the names and values of ``sample`` onto ``cpnest_sample``."""
            cpnest_sample.names = list(sample.keys())
            for position, new_value in enumerate(sample.values()):
                cpnest_sample.values[position] = new_value
            return cpnest_sample

    wrapped = CPNestEnsembleProposal(jump_proposal)
    return wrapped

227 

228 

def cpnest_proposal_cycle_factory(jump_proposals):
    """Build a cpnest ``ProposalCycle`` subclass around a bilby proposal cycle.

    Parameters
    ==========
    jump_proposals:
        The bilby proposal cycle; it is deep-copied on instantiation so the
        caller's object is not modified.

    Returns
    =======
    The ``CPNestProposalCycle`` class itself (not an instance); it takes no
    constructor arguments.
    """
    from cpnest.proposal import ProposalCycle

    class CPNestProposalCycle(ProposalCycle):
        def __init__(self):
            cycle = copy.deepcopy(jump_proposals)
            self.jump_proposals = cycle
            # Replace each bilby proposal, in place, by its cpnest wrapper.
            for index, bilby_proposal in enumerate(cycle.proposal_functions):
                cycle.proposal_functions[index] = cpnest_proposal_factory(
                    bilby_proposal
                )
            cycle.update_cycle()
            super().__init__(
                proposals=cycle.proposal_functions,
                weights=cycle.weights,
                cyclelength=cycle.cycle_length,
            )

        def get_sample(self, old, **kwargs):
            """Delegate to the wrapped cycle, passing the current ensemble."""
            return self.jump_proposals(
                sample=old, coordinates=self.ensemble, **kwargs
            )

        def set_ensemble(self, ensemble):
            """Store the ensemble that is handed to the proposals."""
            self.ensemble = ensemble

    return CPNestProposalCycle