Coverage for bilby/core/sampler/kombine.py: 39%
98 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import os
3import numpy as np
5from ..utils import logger
6from .base_sampler import signal_wrapper
7from .emcee import Emcee
8from .ptemcee import LikePriorEvaluator
# Module-level likelihood/prior evaluator. Its ``call_emcee`` method is passed
# to ``kombine.Sampler`` as ``lnpostfn`` (see ``Kombine.sampler_init_kwargs``).
# NOTE(review): presumably kept at module level so the callable can be pickled
# and shared with pool workers — confirm.
_evaluator = LikePriorEvaluator()
class Kombine(Emcee):
    """bilby wrapper kombine (https://github.com/bfarr/kombine)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `kombine.Sampler`, see
    documentation for that class for further help. Under Parameters, we
    list commonly used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    iterations: int, (500)
        The number of iterations. Must not be smaller than `nwalkers`.
    autoburnin: bool (False)
        Use `kombine`'s automatic burnin (at your own risk)
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `iterations` and
        therefore the value must be less than `iterations`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float (3.)
        The number of autocorrelation times to discard as burn-in

    """

    sampler_name = "kombine"
    default_kwargs = dict(
        nwalkers=500,
        args=[],
        pool=None,
        transd=False,
        lnpost0=None,
        blob0=None,
        iterations=500,
        storechain=True,
        processes=1,
        update_interval=None,
        kde=None,
        kde_size=None,
        spaces=None,
        freeze_transd=False,
        test_steps=16,
        critical_pval=0.05,
        max_steps=None,
        burnin_verbose=False,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        autoburnin=False,
        **kwargs,
    ):
        """Set up the sampler; see the class docstring for the parameters.

        Raises
        ======
        ValueError
            If ``nwalkers`` is larger than ``iterations``.
        """
        super(Kombine, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            pos0=pos0,
            nburn=nburn,
            burn_in_fraction=burn_in_fraction,
            burn_in_act=burn_in_act,
            resume=resume,
            **kwargs,
        )

        # Reject runs that request fewer iterations than walkers.
        if self.kwargs["nwalkers"] > self.kwargs["iterations"]:
            raise ValueError("Kombine Sampler requires Iterations be > nWalkers")
        self.autoburnin = autoburnin

    def _check_version(self):
        """Disable the parent class's emcee prerelease/version check."""
        # set prerelease to False to prevent checks for newer emcee versions in parent class
        self.prerelease = False

    @property
    def sampler_function_kwargs(self):
        """Keyword arguments forwarded to ``kombine.Sampler.sample``.

        Only the listed keys that are present in ``self.kwargs`` are
        forwarded; the starting position ``p0`` is always ``self.pos0``.
        """
        keys = [
            "lnpost0",
            "blob0",
            "iterations",
            "storechain",
            "lnprop0",
            "update_interval",
            "kde",
            "kde_size",
            "spaces",
            "freeze_transd",
        ]
        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
        function_kwargs["p0"] = self.pos0
        return function_kwargs

    @property
    def sampler_burnin_kwargs(self):
        """Keyword arguments forwarded to ``kombine.Sampler.burnin``.

        Starts from ``sampler_function_kwargs``, adds the burn-in test
        settings, renames ``burnin_verbose`` to ``verbose`` and drops the keys
        that only apply to ``sample`` (``iterations``, ``spaces``,
        ``freeze_transd``).
        """
        extra_keys = ["test_steps", "critical_pval", "max_steps", "burnin_verbose"]
        removal_keys = ["iterations", "spaces", "freeze_transd"]
        burnin_kwargs = self.sampler_function_kwargs.copy()
        for key in extra_keys:
            if key in self.kwargs:
                burnin_kwargs[key] = self.kwargs[key]
        if "burnin_verbose" in burnin_kwargs:
            burnin_kwargs["verbose"] = burnin_kwargs.pop("burnin_verbose")
        for key in removal_keys:
            burnin_kwargs.pop(key, None)
        return burnin_kwargs

    @property
    def sampler_init_kwargs(self):
        """Keyword arguments for the ``kombine.Sampler`` constructor.

        Every entry of ``self.kwargs`` not consumed by ``sample``/``burnin``,
        plus the log-posterior function and the dimensionality.
        """
        # Build each property once instead of once per key in the filter below.
        function_kwargs = self.sampler_function_kwargs
        burnin_kwargs = self.sampler_burnin_kwargs
        init_kwargs = {
            key: value
            for key, value in self.kwargs.items()
            if key not in function_kwargs and key not in burnin_kwargs
        }
        # ``burnin_verbose`` is renamed to ``verbose`` for burnin, so it is not
        # filtered out above; it must not reach the constructor. Tolerate its
        # absence rather than raising KeyError.
        init_kwargs.pop("burnin_verbose", None)
        init_kwargs["lnpostfn"] = _evaluator.call_emcee
        init_kwargs["ndim"] = self.ndim

        return init_kwargs

    def _initialise_sampler(self):
        """Create the ``kombine.Sampler`` and initialise the chain file."""
        import kombine

        self._sampler = kombine.Sampler(**self.sampler_init_kwargs)
        self._init_chain_file()

    def _set_pos0_for_resume(self):
        """Restart from the walker positions of the last stored iteration."""
        # take last iteration
        self.pos0 = self.sampler.chain[-1, :, :]

    @property
    def sampler_chain(self):
        """Chain truncated to the first ``_previous_iterations`` iterations."""
        # remove last iterations when resuming
        nsteps = self._previous_iterations
        return self.sampler.chain[:nsteps, :, :]

    def check_resume(self):
        """Return True if resuming is enabled and a non-empty sampler file exists."""
        return (
            self.resume
            and os.path.isfile(self.checkpoint_info.sampler_file)
            and os.path.getsize(self.checkpoint_info.sampler_file) > 0
        )

    @signal_wrapper
    def run_sampler(self):
        """Run kombine (optionally with auto burn-in) and fill ``self.result``.

        Returns
        =======
        The populated ``self.result``.
        """
        self._setup_pool()
        if self.autoburnin:
            if self.check_resume():
                logger.info("Resuming with autoburnin=True skips burnin process:")
            else:
                logger.info("Running kombine sampler's automatic burnin process")
                self.sampler.burnin(**self.sampler_burnin_kwargs)
                # Steps taken during burn-in are counted by
                # ``_previous_iterations``; extend the requested total so the
                # user still gets ``iterations`` further steps, and mark the
                # burn-in steps for removal.
                self.kwargs["iterations"] += self._previous_iterations
                self.nburn = self._previous_iterations
                logger.info(
                    f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains"
                )
                self._set_pos0_for_resume()

        from tqdm.auto import tqdm

        sampler_function_kwargs = self.sampler_function_kwargs
        # Only draw the steps not already stored (e.g. from a resumed run);
        # never request a negative number of iterations.
        iterations = max(
            sampler_function_kwargs.pop("iterations") - self._previous_iterations, 0
        )
        sampler_function_kwargs["p0"] = self.pos0
        for sample in tqdm(
            self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
            total=iterations,
        ):
            self.write_chains_to_file(sample)
        self.write_current_state()
        self.result.sampler_output = np.nan
        if not self.autoburnin:
            tmp_chain = self.sampler.chain.copy()
            self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim)))
            self.print_nburn_logging_info()
        self._close_pool()

        self._generate_result()
        self.result.log_evidence_err = np.nan

        # kombine's chain is indexed (iteration, walker, parameter).
        tmp_chain = self.sampler.chain[self.nburn :, :, :].copy()
        self.result.samples = tmp_chain.reshape((-1, self.ndim))
        # ``result.walkers`` is (nwalkers, nsteps, ndim): swap the first two
        # axes. A plain reshape would interleave steps from different walkers.
        self.result.walkers = np.swapaxes(self.sampler.chain, 0, 1)
        return self.result

    def _setup_pool(self):
        """Set up the parent's pool, falling back to kombine's ``SerialPool``."""
        from kombine import SerialPool

        super(Kombine, self)._setup_pool()
        if self.pool is None:
            self.pool = SerialPool()

    def _close_pool(self):
        """Drop a ``SerialPool`` fallback, then let the parent close its pool."""
        from kombine import SerialPool

        if isinstance(self.pool, SerialPool):
            self.pool = None
        super(Kombine, self)._close_pool()