Coverage for bilby/core/sampler/ptmcmc.py: 41%
98 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
import glob
import os
import shutil

import numpy as np

from ..utils import logger
from .base_sampler import MCMCSampler, SamplerNotInstalledError, signal_wrapper
class PTMCMCSampler(MCMCSampler):
    """bilby wrapper of PTMCMC (https://github.com/jellis18/PTMCMCSampler/)

    Keyword arguments passed to `run_sampler` whose names appear in
    `default_kwargs` are forwarded either to the `PTMCMCSampler.PTSampler`
    constructor (see `sampler_init_kwargs`) or to its `sample` method (see
    `sampler_function_kwargs`); see the PTMCMCSampler documentation for
    further help. Under Parameters, we list commonly used kwargs and the
    bilby defaults.

    Parameters
    ==========
    Niter: int (2*10**4 + 1)
        The number of mcmc steps
    burn: int (5 * 10**3)
        If given, the fixed number of steps to discard as burn-in
    thin: int (1)
        The number of steps before saving the sample to the chain
    custom_proposals: dict (None)
        Add dictionary of proposals to the array of proposals, this must be in
        the form of a dictionary with the name of the proposal, then a list
        containing the jump function and the weight e.g {'name' : [function ,
        weight]} see
        (https://github.com/rgreen1995/PTMCMCSampler/blob/master/examples/simple.ipynb)
        and
        (http://jellis18.github.io/PTMCMCSampler/PTMCMCSampler.html#ptmcmcsampler-ptmcmcsampler-module)
        for examples and more info.
    logl_grad: func (None)
        Gradient of likelihood if known (default = None)
    logp_grad: func (None)
        Gradient of prior if known (default = None)
    verbose: bool (True)
        Update current run-status to the screen
    outDir: str (None)
        Directory PTMCMC writes its chain files to. Defaults to
        ``{outdir}/ptmcmc_temp_{label}/``. This directory is deleted once
        the chains have been read back in, including when user-supplied.

    """

    sampler_name = "ptmcmcsampler"
    abbreviation = "ptmcmc_temp"
    default_kwargs = {
        "p0": None,
        "Niter": 2 * 10**4 + 1,
        "neff": 10**4,
        "burn": 5 * 10**3,
        "verbose": True,
        "ladder": None,
        "Tmin": 1,
        "Tmax": None,
        "Tskip": 100,
        "isave": 1000,
        "thin": 1,
        "covUpdate": 1000,
        "SCAMweight": 1,
        "AMweight": 1,
        "DEweight": 1,
        "HMCweight": 0,
        "MALAweight": 0,
        "NUTSweight": 0,
        "HMCstepsize": 0.1,
        "HMCsteps": 300,
        "groups": None,
        "custom_proposals": None,
        "loglargs": {},
        "loglkwargs": {},
        "logpargs": {},
        "logpkwargs": {},
        "logl_grad": None,
        "logp_grad": None,
        "outDir": None,
    }
    hard_exit = True

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        **kwargs,
    ):
        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            **kwargs,
        )

        # Start the chain from the user-supplied p0, or a random prior draw.
        if self.kwargs["p0"] is None:
            self.p0 = self.get_random_draw_from_prior()
        else:
            self.p0 = self.kwargs["p0"]
        self.likelihood = likelihood
        self.priors = priors

    def _verify_external_sampler(self):
        """Check that the ``PTMCMCSampler`` package can be imported.

        The parent implementation lower-cases the class name before
        importing, but this package's module name is capitalised, so the
        class name is imported verbatim instead.

        Raises
        ======
        SamplerNotInstalledError
            If importing the package raises ImportError or SystemExit.
        """
        external_sampler_name = self.__class__.__name__
        try:
            __import__(external_sampler_name)
        except (ImportError, SystemExit):
            raise SamplerNotInstalledError(
                f"Sampler {external_sampler_name} is not installed on this system"
            )

    def _translate_kwargs(self, kwargs):
        """Rename bilby's generic MCMC kwarg aliases to PTMCMC's names.

        Any key in ``nwalkers_equiv_kwargs`` is moved to ``Niter`` and any
        key in ``nburn_equiv_kwargs`` is moved to ``burn``, unless the
        PTMCMC name was given explicitly.

        Parameters
        ==========
        kwargs: dict
            User-supplied keyword arguments; modified in place.

        Returns
        =======
        dict
            The translated keyword arguments.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "Niter" not in kwargs:
            for equiv in self.nwalkers_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["Niter"] = kwargs.pop(equiv)
        if "burn" not in kwargs:
            for equiv in self.nburn_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["burn"] = kwargs.pop(equiv)
        # Return the dict so callers relying on the return value (as this
        # method itself does with the parent's) receive it rather than None.
        return kwargs

    @property
    def custom_proposals(self):
        """dict or None: user proposals, ``{name: [jump_function, weight]}``."""
        return self.kwargs["custom_proposals"]

    @property
    def sampler_init_kwargs(self):
        """dict: kwargs forwarded to the ``PTSampler`` constructor.

        ``outDir`` defaults to ``{outdir}/ptmcmc_temp_{label}/`` when unset.
        """
        keys = [
            "groups",
            "loglargs",
            "logp_grad",
            "logpkwargs",
            "loglkwargs",
            "logl_grad",
            "logpargs",
            "outDir",
            "verbose",
        ]
        init_kwargs = {key: self.kwargs[key] for key in keys}
        if init_kwargs["outDir"] is None:
            init_kwargs["outDir"] = f"{self.outdir}/ptmcmc_temp_{self.label}/"
        return init_kwargs

    @property
    def sampler_function_kwargs(self):
        """dict: kwargs forwarded to ``PTSampler.sample``."""
        keys = [
            "Niter",
            "neff",
            "Tmin",
            "HMCweight",
            "covUpdate",
            "SCAMweight",
            "ladder",
            "burn",
            "NUTSweight",
            "AMweight",
            "MALAweight",
            "thin",
            "HMCstepsize",
            "isave",
            "Tskip",
            "HMCsteps",
            "Tmax",
            "DEweight",
        ]
        sampler_kwargs = {key: self.kwargs[key] for key in keys}
        return sampler_kwargs

    @staticmethod
    def _import_external_sampler():
        """Import and return the ``PTMCMCSampler.PTMCMCSampler`` module."""
        from PTMCMCSampler import PTMCMCSampler

        return PTMCMCSampler

    @signal_wrapper
    def run_sampler(self):
        """Run PTMCMC, read its chain back from disk and fill ``self.result``.

        Returns
        =======
        Result
            ``self.result`` with samples and log-likelihoods after the
            first ``burn`` rows removed; evidence fields are set to NaN.
        """
        ptmcmc_module = self._import_external_sampler()
        sampler = ptmcmc_module.PTSampler(
            ndim=self.ndim,
            logp=self.log_prior,
            logl=self.log_likelihood,
            cov=np.eye(self.ndim),
            **self.sampler_init_kwargs,
        )
        if self.custom_proposals is not None:
            for proposal, spec in self.custom_proposals.items():
                jump_function, weight = spec[0], spec[1]
                logger.info(
                    f"Adding {proposal} to proposals with weight {weight}"
                )
                sampler.addProposalToCycle(jump_function, weight)

        function_kwargs = self.sampler_function_kwargs
        sampler.sample(p0=self.p0, **function_kwargs)
        samples, meta, loglike = self.__read_in_data()

        self.calc_likelihood_count()
        # NOTE(review): if PTMCMC writes only every `thin`-th step to the
        # chain file, slicing by `burn` rows discards burn * thin steps —
        # confirm against PTMCMC before using thin > 1.
        self.result.nburn = function_kwargs["burn"]
        self.result.samples = samples[self.result.nburn :]
        self.meta_data["sampler_meta"] = meta
        self.result.log_likelihood_evaluations = loglike[self.result.nburn :]
        self.result.sampler_output = np.nan
        self.result.walkers = np.nan
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan
        return self.result

    def __read_in_data(self):
        """Read the chain and jump files PTMCMC wrote, then delete them.

        Returns
        =======
        samples: np.ndarray
            Every chain-file column except the trailing four.
        meta: dict
            ``tot_accept``, ``PT_swap``, ``proposals`` (per-jump files keyed
            by their name prefix) and ``log_post``.
        loglike: np.ndarray
            The log-likelihood column of the chain file.
        """
        temp_outDir = self.sampler_init_kwargs["outDir"]
        # os.path.join works whether or not a user-supplied outDir ends in a
        # path separator. ndmin=2 keeps a single-row chain two-dimensional
        # so the column slicing below still works.
        try:
            data = np.loadtxt(os.path.join(temp_outDir, "chain_1.txt"), ndmin=2)
        except OSError:
            # The cold chain may instead be written as chain_1.0.txt.
            data = np.loadtxt(os.path.join(temp_outDir, "chain_1.0.txt"), ndmin=2)

        # Sorted so the order of meta["proposals"] is deterministic.
        jumpfiles = sorted(glob.glob(os.path.join(temp_outDir, "*jump.txt")))
        jump_accept = {
            os.path.basename(path).split("_jump.txt")[0]: np.loadtxt(path)
            for path in jumpfiles
        }

        # Trailing columns: log-posterior, log-likelihood, total acceptance,
        # PT swap acceptance; everything before them is a parameter value.
        samples = data[:, :-4]
        loglike = data[:, -3]
        meta = {
            "tot_accept": {"tot_accept": data[:, -2]},
            "PT_swap": {"swap_accept": data[:, -1]},
            "proposals": jump_accept,
            "log_post": {"log_post": data[:, -4]},
        }

        # Removes the whole output directory, even if the user supplied it.
        shutil.rmtree(temp_outDir)

        return samples, meta, loglike

    def write_current_state(self):
        """TODO: implement a checkpointing method"""
        pass