Coverage for bilby/core/sampler/cpnest.py: 26%
144 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import array
2import copy
3import sys
5import numpy as np
6from numpy.lib.recfunctions import structured_to_unstructured
7from pandas import DataFrame
9from ..utils import check_directory_exists_and_if_not_mkdir, logger
10from .base_sampler import NestedSampler, signal_wrapper
11from .proposal import JumpProposalCycle, Sample
class Cpnest(NestedSampler):
    """bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `cpnest.CPNest`, see documentation
    for that class for further help. Below, we list commonly used kwargs and
    the bilby defaults.

    Parameters
    ==========
    nlive: int (500)
        The number of live points, note this can also equivalently be given as
        one of [npoints, nlives, n_live_points]
    seed: int (None)
        Initial random seed. If no seed is given a warning is logged and
        cpnest falls back to a seed of 1234.
    nthreads: int (1)
        Number of threads to use, note this can also equivalently be given
        using any of the sampler's ``npool``-equivalent keyword arguments
    maxmcmc: int (1000)
        The maximum number of MCMC steps to take
    verbose: int (3)
        Verbosity setting passed through to cpnest, controlling how much
        information about the convergence is printed during sampling
    resume: Bool (True)
        Whether or not to resume from a previous run
    output: str
        Where to write the CPNest output, by default this is
        {self.outdir}/cpnest_{self.label}/
    proposals: JumpProposalCycle or dict (None)
        Custom jump proposals. A single bilby ``JumpProposalCycle`` is used
        for both the ``mhs`` and ``hmc`` entries; a dict maps each entry name
        to a bilby ``JumpProposalCycle`` or a cpnest ``ProposalCycle``.
    n_periodic_checkpoint: int (8000)
        Passed through to cpnest; set to None when ``resume`` is False.

    """

    sampler_name = "cpnest"
    default_kwargs = dict(
        verbose=3,
        nthreads=1,
        nlive=500,
        maxmcmc=1000,
        seed=None,
        poolsize=100,
        nhamiltonian=0,
        resume=True,
        output=None,
        proposals=None,
        n_periodic_checkpoint=8000,
    )
    hard_exit = True
    sampling_seed_key = "seed"

    def _translate_kwargs(self, kwargs):
        """Map bilby's equivalent keyword names onto the names cpnest expects.

        ``kwargs`` is modified in place and also returned, following the same
        contract as the base-class method whose return value is used below.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "nlive" not in kwargs:
            # Pop every equivalent that is present so none of them is
            # forwarded to cpnest.CPNest as an unknown keyword.
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nlive"] = kwargs.pop(equiv)
        if "nthreads" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nthreads"] = kwargs.pop(equiv)

        if "seed" not in kwargs:
            logger.warning("No seed provided, cpnest will use 1234.")
        return kwargs

    @signal_wrapper
    def run_sampler(self):
        """Run cpnest and store its output on ``self.result``.

        Returns
        =======
        The sampler's ``result`` object, populated with posterior samples,
        log-likelihood evaluations, weighted nested samples, the log evidence
        and its error estimate, and the information gain.
        """
        from cpnest import CPNest
        from cpnest import model as cpmodel
        from cpnest.nest2pos import compute_weights
        from cpnest.parameter import LivePoint

        # The static methods of Model close over this alias (the bilby
        # sampler), not over the Model instance.
        sampler = self

        class Model(cpmodel.Model):
            """A wrapper class to pass our log_likelihood into cpnest"""

            def __init__(self, names, priors):
                self.names = names
                self.priors = priors
                self._update_bounds()

            @staticmethod
            def log_likelihood(x, **kwargs):
                theta = [x[n] for n in sampler.search_parameter_keys]
                return sampler.log_likelihood(theta)

            @staticmethod
            def log_prior(x, **kwargs):
                theta = [x[n] for n in sampler.search_parameter_keys]
                return sampler.log_prior(theta)

            def _update_bounds(self):
                self.bounds = [
                    [self.priors[key].minimum, self.priors[key].maximum]
                    for key in self.names
                ]

            def new_point(self):
                """Draw a point from the prior"""
                prior_samples = self.priors.sample()
                # NOTE(review): bounds are refreshed on every draw, presumably
                # in case prior limits change after construction — confirm.
                self._update_bounds()
                point = LivePoint(
                    self.names,
                    array.array("d", [prior_samples[name] for name in self.names]),
                )
                return point

        self._resolve_proposal_functions()
        model = Model(self.search_parameter_keys, self.priors)
        out = None
        # If CPNest rejects the kwargs (presumably an older cpnest release,
        # hence the "please update" message), drop these one at a time.
        remove_kwargs = ["proposals", "n_periodic_checkpoint"]
        while out is None:
            try:
                out = CPNest(model, **self.kwargs)
            except TypeError as e:
                if not remove_kwargs:
                    raise TypeError("Unable to initialise cpnest sampler") from e
                kwarg = remove_kwargs.pop(0)
                logger.info(f"CPNest init. failed with error {e}, please update")
                logger.info(f"Attempting to rerun with kwarg {kwarg} removed")
                self.kwargs.pop(kwarg, None)
        try:
            out.run()
        except SystemExit:
            # Let cpnest write its own checkpoint before bilby exits.
            out.checkpoint()
            self.write_current_state_and_exit()

        if self.plot:
            out.plot()

        self.calc_likelihood_count()
        self.result.samples = structured_to_unstructured(
            out.posterior_samples[self.search_parameter_keys]
        )
        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=""))
        self.result.nested_samples.rename(
            columns=dict(logL="log_likelihood"), inplace=True
        )
        _, log_weights = compute_weights(
            np.array(self.result.nested_samples.log_likelihood),
            np.array(out.NS.state.nlive),
        )
        self.result.nested_samples["weights"] = np.exp(log_weights)
        self.result.log_evidence = out.NS.state.logZ
        # Evidence error estimated as sqrt(information / nlive).
        self.result.log_evidence_err = np.sqrt(out.NS.state.info / out.NS.state.nlive)
        self.result.information_gain = out.NS.state.info
        return self.result

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Overwrites the base class to make sure that :code:`CPNest` terminates
        properly as :code:`CPNest` handles all the multiprocessing internally.
        """
        self._log_interruption(signum=signum)
        sys.exit(self.exit_code)

    def _verify_kwargs_against_default_kwargs(self):
        """
        Set the directory where the output will be written
        and check resume and checkpoint status.
        """
        if not self.kwargs["output"]:
            self.kwargs["output"] = f"{self.outdir}/cpnest_{self.label}/"
        # Accept path-like values (e.g. pathlib.Path) as well as strings; the
        # trailing-slash check below needs a str.
        output = str(self.kwargs["output"])
        if not output.endswith("/"):
            output = f"{output}/"
        self.kwargs["output"] = output
        check_directory_exists_and_if_not_mkdir(output)
        # Periodic checkpointing is disabled when resuming is turned off.
        if self.kwargs["n_periodic_checkpoint"] and not self.kwargs["resume"]:
            self.kwargs["n_periodic_checkpoint"] = None
        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def _resolve_proposal_functions(self):
        """Convert the ``proposals`` kwarg into cpnest-compatible proposals.

        A single bilby ``JumpProposalCycle`` is used for both the ``mhs`` and
        ``hmc`` entries. Each bilby ``JumpProposalCycle`` value is converted
        with ``cpnest_proposal_cycle_factory``; cpnest ``ProposalCycle``
        values are kept as they are.

        Raises
        ======
        TypeError
            If ``proposals`` is neither a ``JumpProposalCycle`` nor a mapping,
            or if one of its values is of an unsupported type.
        """
        from collections.abc import Mapping

        from cpnest.proposal import ProposalCycle

        proposals = self.kwargs.get("proposals")
        if proposals is None:
            return
        if isinstance(proposals, JumpProposalCycle):
            proposals = dict(mhs=proposals, hmc=proposals)
        if not isinstance(proposals, Mapping):
            raise TypeError("Unknown proposal type")
        # Build a new dict instead of assigning into the caller's. The
        # converted values are classes, which this method would reject if the
        # same dict were passed in again.
        resolved = {}
        for key, proposal in proposals.items():
            if isinstance(proposal, JumpProposalCycle):
                resolved[key] = cpnest_proposal_cycle_factory(proposal)
            elif isinstance(proposal, ProposalCycle):
                resolved[key] = proposal
            else:
                raise TypeError("Unknown proposal type")
        self.kwargs["proposals"] = resolved
def cpnest_proposal_factory(jump_proposal):
    """Wrap a single bilby jump proposal so that cpnest can call it.

    Parameters
    ==========
    jump_proposal: callable
        A bilby jump proposal. It is called with ``sample`` (a bilby
        ``Sample``), ``sampler_name="cpnest"`` and any extra keyword
        arguments, and its ``log_j`` attribute is read after each call.

    Returns
    =======
    CPNestEnsembleProposal
        An instance of a ``cpnest.proposal.EnsembleProposal`` subclass that
        converts cpnest live points to bilby samples and back.
    """
    import cpnest.proposal

    class CPNestEnsembleProposal(cpnest.proposal.EnsembleProposal):
        def __init__(self, jp):
            self.jump_proposal = jp
            self.ensemble = None

        def __call__(self, sample, **kwargs):
            return self.get_sample(sample, **kwargs)

        def get_sample(self, cpnest_sample, **kwargs):
            # Convert cpnest's live point into a bilby Sample for the proposal.
            proposed = Sample.from_cpnest_live_point(cpnest_sample)
            # Use the ensemble handed over as ``coordinates`` if one is given;
            # otherwise keep the previously stored ensemble.
            self.ensemble = kwargs.get("coordinates", self.ensemble)
            proposed = self.jump_proposal(
                sample=proposed,
                sampler_name="cpnest",
                **kwargs,
            )
            # Copy the proposal's log Jacobian onto ``log_J``, presumably the
            # attribute cpnest reads for the jump — confirm against cpnest.
            self.log_J = self.jump_proposal.log_j
            return self._update_cpnest_sample(cpnest_sample, proposed)

        @staticmethod
        def _update_cpnest_sample(cpnest_sample, sample):
            # Write the proposed names and values back into the live point.
            cpnest_sample.names = list(sample.keys())
            for position, new_value in enumerate(sample.values()):
                cpnest_sample.values[position] = new_value
            return cpnest_sample

    wrapped = CPNestEnsembleProposal(jump_proposal)
    return wrapped
def cpnest_proposal_cycle_factory(jump_proposals):
    """Build a cpnest ``ProposalCycle`` subclass from a bilby proposal cycle.

    The returned *class* (not an instance) deep-copies ``jump_proposals``
    whenever it is instantiated. It wraps each of the copy's proposal
    functions with ``cpnest_proposal_factory``.

    Parameters
    ==========
    jump_proposals: JumpProposalCycle
        The bilby proposal cycle to adapt. It is never modified.

    Returns
    =======
    type
        The ``CPNestProposalCycle`` class.
    """
    import cpnest.proposal

    class CPNestProposalCycle(cpnest.proposal.ProposalCycle):
        def __init__(self):
            # Each instance works on its own copy, so the caller's cycle is
            # left untouched.
            self.jump_proposals = copy.deepcopy(jump_proposals)
            funcs = self.jump_proposals.proposal_functions
            # Replace the contents in place, keeping the same list object.
            funcs[:] = [cpnest_proposal_factory(func) for func in funcs]
            self.jump_proposals.update_cycle()
            super().__init__(
                proposals=self.jump_proposals.proposal_functions,
                weights=self.jump_proposals.weights,
                cyclelength=self.jump_proposals.cycle_length,
            )

        def get_sample(self, old, **kwargs):
            # Delegate to the bilby cycle, handing over the stored ensemble.
            return self.jump_proposals(
                sample=old,
                coordinates=self.ensemble,
                **kwargs,
            )

        def set_ensemble(self, ensemble):
            self.ensemble = ensemble

    return CPNestProposalCycle