Coverage for bilby/core/sampler/kombine.py: 39%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import os 

2 

3import numpy as np 

4 

5from ..utils import logger 

6from .base_sampler import signal_wrapper 

7from .emcee import Emcee 

8from .ptemcee import LikePriorEvaluator 

9 

# Shared likelihood/prior evaluator; its ``call_emcee`` method is handed to
# ``kombine.Sampler`` as ``lnpostfn`` in ``Kombine.sampler_init_kwargs``.
# NOTE(review): presumably kept at module level (rather than bound to the
# sampler instance) so it can be pickled into worker pools — TODO confirm.
_evaluator = LikePriorEvaluator()

11 

12 

class Kombine(Emcee):
    """bilby wrapper kombine (https://github.com/bfarr/kombine)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `kombine.Sampler` (and to its
    `sample` and `burnin` methods, see the `sampler_*_kwargs` properties);
    see the documentation for that class for further help. Under Parameters,
    we list commonly used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    iterations: int, (500)
        The number of iterations; must be at least `nwalkers`
    autoburnin: bool (False)
        Use `kombine`'s automatic burnin (at your own risk)
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `iterations` and
        therefore the value must be less than `iterations`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float (3.)
        The number of autocorrelation times to discard as burn-in

    """

    sampler_name = "kombine"
    default_kwargs = dict(
        nwalkers=500,
        args=[],
        pool=None,
        transd=False,
        lnpost0=None,
        blob0=None,
        iterations=500,
        storechain=True,
        processes=1,
        update_interval=None,
        kde=None,
        kde_size=None,
        spaces=None,
        freeze_transd=False,
        test_steps=16,
        critical_pval=0.05,
        max_steps=None,
        burnin_verbose=False,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        autoburnin=False,
        **kwargs,
    ):
        super(Kombine, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            pos0=pos0,
            nburn=nburn,
            burn_in_fraction=burn_in_fraction,
            burn_in_act=burn_in_act,
            resume=resume,
            **kwargs,
        )

        if self.kwargs["nwalkers"] > self.kwargs["iterations"]:
            raise ValueError("Kombine Sampler requires Iterations be > nWalkers")
        self.autoburnin = autoburnin

    def _check_version(self):
        """Disable the parent's emcee version check (not relevant to kombine)."""
        # set prerelease to False to prevent checks for newer emcee versions in parent class
        self.prerelease = False

    @property
    def sampler_function_kwargs(self):
        """Keyword arguments for `kombine.Sampler.sample`.

        Only the listed keys that are present in `self.kwargs` are forwarded;
        the starting position `p0` is always taken from `self.pos0`.
        """
        keys = [
            "lnpost0",
            "blob0",
            "iterations",
            "storechain",
            "lnprop0",
            "update_interval",
            "kde",
            "kde_size",
            "spaces",
            "freeze_transd",
        ]
        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
        function_kwargs["p0"] = self.pos0
        return function_kwargs

    @property
    def sampler_burnin_kwargs(self):
        """Keyword arguments for `kombine.Sampler.burnin`.

        Starts from `sampler_function_kwargs`, adds the burn-in settings,
        renames bilby's `burnin_verbose` to kombine's `verbose`, and drops
        the sampling-only keys `iterations`, `spaces` and `freeze_transd`.
        """
        extra_keys = ["test_steps", "critical_pval", "max_steps", "burnin_verbose"]
        removal_keys = ["iterations", "spaces", "freeze_transd"]
        burnin_kwargs = self.sampler_function_kwargs.copy()
        for key in extra_keys:
            if key in self.kwargs:
                burnin_kwargs[key] = self.kwargs[key]
        if "burnin_verbose" in burnin_kwargs:
            burnin_kwargs["verbose"] = burnin_kwargs.pop("burnin_verbose")
        for key in removal_keys:
            burnin_kwargs.pop(key, None)
        return burnin_kwargs

    @property
    def sampler_init_kwargs(self):
        """Keyword arguments for the `kombine.Sampler` constructor.

        Every entry of `self.kwargs` not consumed by `sample` or `burnin`,
        plus the log-posterior callable and the dimensionality.
        """
        # Evaluate each (dict-building) property once instead of once per key.
        excluded = set(self.sampler_function_kwargs) | set(self.sampler_burnin_kwargs)
        init_kwargs = {
            key: value for key, value in self.kwargs.items() if key not in excluded
        }
        # `burnin_verbose` is renamed to `verbose` in the burn-in kwargs, so it
        # survives the filter above; it is not a constructor argument.
        init_kwargs.pop("burnin_verbose", None)
        init_kwargs["lnpostfn"] = _evaluator.call_emcee
        init_kwargs["ndim"] = self.ndim

        return init_kwargs

    def _initialise_sampler(self):
        """Create the `kombine.Sampler` and initialise the chain file."""
        import kombine

        self._sampler = kombine.Sampler(**self.sampler_init_kwargs)
        self._init_chain_file()

    def _set_pos0_for_resume(self):
        """Restart from the walker positions of the last stored iteration."""
        # take last iteration
        self.pos0 = self.sampler.chain[-1, :, :]

    @property
    def sampler_chain(self):
        """The stored chain truncated to the previously completed iterations."""
        # remove last iterations when resuming
        nsteps = self._previous_iterations
        return self.sampler.chain[:nsteps, :, :]

    def check_resume(self):
        """Return True if resuming is enabled and a non-empty checkpoint exists."""
        return (
            self.resume
            and os.path.isfile(self.checkpoint_info.sampler_file)
            and os.path.getsize(self.checkpoint_info.sampler_file) > 0
        )

    @signal_wrapper
    def run_sampler(self):
        """Run kombine (optionally with its auto burn-in) and build the result.

        kombine stores its chain iteration-major, i.e. with shape
        (iterations, nwalkers, ndim). `result.walkers` is returned walker-major,
        (nwalkers, iterations, ndim), as for the emcee wrapper.
        """
        self._setup_pool()
        if self.autoburnin:
            if self.check_resume():
                logger.info("Resuming with autoburnin=True skips burnin process:")
            else:
                logger.info("Running kombine sampler's automatic burnin process")
                self.sampler.burnin(**self.sampler_burnin_kwargs)
                # The burn-in steps are kept in the chain: extend the run by that
                # many iterations and discard exactly those steps afterwards.
                self.kwargs["iterations"] += self._previous_iterations
                self.nburn = self._previous_iterations
                logger.info(
                    f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains"
                )
                self._set_pos0_for_resume()

        from tqdm.auto import tqdm

        sampler_function_kwargs = self.sampler_function_kwargs
        # Only run the iterations not already stored (resume or burn-in).
        iterations = sampler_function_kwargs.pop("iterations")
        iterations -= self._previous_iterations
        sampler_function_kwargs["p0"] = self.pos0
        for sample in tqdm(
            self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
            total=iterations,
        ):
            self.write_chains_to_file(sample)
        self.write_current_state()
        self.result.sampler_output = np.nan

        # Axis 0 of kombine's chain is the iteration; swap to walker-major so a
        # plain reshape neither scrambles samples (result.walkers) nor mixes
        # walkers between lags (autocorrelation, whose time is used to slice
        # iterations below).
        walkers = np.swapaxes(self.sampler.chain, 0, 1)
        if not self.autoburnin:
            self.calculate_autocorrelation(walkers.reshape((-1, self.ndim)))
            self.print_nburn_logging_info()
        self._close_pool()

        self._generate_result()
        self.result.log_evidence_err = np.nan

        post_burnin = self.sampler.chain[self.nburn :, :, :].copy()
        self.result.samples = post_burnin.reshape((-1, self.ndim))
        self.result.walkers = walkers
        return self.result

    def _setup_pool(self):
        """Set up the parent pool, falling back to kombine's `SerialPool`."""
        from kombine import SerialPool

        super(Kombine, self)._setup_pool()
        if self.pool is None:
            self.pool = SerialPool()

    def _close_pool(self):
        """Drop a `SerialPool` placeholder, then run the parent's pool cleanup."""
        from kombine import SerialPool

        # NOTE(review): presumably the parent only knows how to close real
        # multiprocessing pools, so the serial placeholder is discarded first.
        if isinstance(self.pool, SerialPool):
            self.pool = None
        super(Kombine, self)._close_pool()