Coverage for bilby/core/sampler/ptmcmc.py: 41%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import glob 

2import shutil 

3 

4import numpy as np 

5 

6from ..utils import logger 

7from .base_sampler import MCMCSampler, SamplerNotInstalledError, signal_wrapper 

8 

9 

class PTMCMCSampler(MCMCSampler):
    """bilby wrapper of PTMCMC (https://github.com/jellis18/PTMCMCSampler/)

    All keyword arguments (i.e., the kwargs) passed to `run_sampler` are
    forwarded to the external sampler: some go to the constructor of
    `PTMCMCSampler.PTSampler` (see `sampler_init_kwargs`) and the others to
    its `sample` method (see `sampler_function_kwargs`); see the documentation
    of that class for further help. Under Other Parameters, we list commonly
    used kwargs and the bilby defaults.

    Other Parameters
    ================
    Niter: int (2*10**4 + 1)
        The number of mcmc steps
    burn: int (5 * 10**3)
        If given, the fixed number of steps to discard as burn-in
    thin: int (1)
        The number of steps before saving the sample to the chain
    custom_proposals: dict (None)
        Add dictionary of proposals to the array of proposals, this must be in
        the form of a dictionary with the name of the proposal, then a list
        containing the jump function and the weight e.g {'name' : [function ,
        weight]} see
        (https://github.com/rgreen1995/PTMCMCSampler/blob/master/examples/simple.ipynb)
        and
        (http://jellis18.github.io/PTMCMCSampler/PTMCMCSampler.html#ptmcmcsampler-ptmcmcsampler-module)
        for examples and more info.
    logl_grad: func (None)
        Gradient of likelihood if known (default = None)
    logp_grad: func (None)
        Gradient of prior if known (default = None)
    verbose: bool (True)
        Update current run-status to the screen

    """

    sampler_name = "ptmcmcsampler"
    abbreviation = "ptmcmc_temp"
    default_kwargs = {
        "p0": None,
        "Niter": 2 * 10**4 + 1,
        "neff": 10**4,
        "burn": 5 * 10**3,
        "verbose": True,
        "ladder": None,
        "Tmin": 1,
        "Tmax": None,
        "Tskip": 100,
        "isave": 1000,
        "thin": 1,
        "covUpdate": 1000,
        "SCAMweight": 1,
        "AMweight": 1,
        "DEweight": 1,
        "HMCweight": 0,
        "MALAweight": 0,
        "NUTSweight": 0,
        "HMCstepsize": 0.1,
        "HMCsteps": 300,
        "groups": None,
        "custom_proposals": None,
        "loglargs": {},
        "loglkwargs": {},
        "logpargs": {},
        "logpkwargs": {},
        "logl_grad": None,
        "logp_grad": None,
        "outDir": None,
    }
    hard_exit = True

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        **kwargs,
    ):
        """Set up the wrapper and choose the starting point of the chain.

        All arguments are handed unchanged to the parent constructor. The
        starting point ``p0`` is taken from the ``p0`` kwarg when given,
        otherwise it is a random draw from the prior.
        """
        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            **kwargs,
        )

        if self.kwargs["p0"] is None:
            self.p0 = self.get_random_draw_from_prior()
        else:
            self.p0 = self.kwargs["p0"]
        self.likelihood = likelihood
        self.priors = priors

    def _verify_external_sampler(self):
        """Check that the external PTMCMCSampler package can be imported.

        Raises
        ======
        SamplerNotInstalledError
            If importing the package raises ImportError or SystemExit.
        """
        # PTMCMC is imported with Caps so need to overwrite the parent function
        # which forces `__name__.lower()`
        external_sampler_name = self.__class__.__name__
        try:
            __import__(external_sampler_name)
        except (ImportError, SystemExit) as error:
            # Chain the original failure so the underlying reason stays
            # visible in the traceback.
            raise SamplerNotInstalledError(
                f"Sampler {external_sampler_name} is not installed on this system"
            ) from error

    def _translate_kwargs(self, kwargs):
        """Map generic bilby kwarg aliases onto ``Niter`` and ``burn``.

        Parameters
        ==========
        kwargs: dict
            User supplied keyword arguments.

        Returns
        =======
        dict
            The translated keyword arguments, returned in the same way this
            method expects its parent to return them.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "Niter" not in kwargs:
            # NOTE(review): ``Niter`` is documented above as a number of steps
            # but is filled from the walker-count aliases - confirm intended.
            for equiv in self.nwalkers_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["Niter"] = kwargs.pop(equiv)
        if "burn" not in kwargs:
            for equiv in self.nburn_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["burn"] = kwargs.pop(equiv)
        return kwargs

    @property
    def custom_proposals(self):
        """The ``custom_proposals`` kwarg (None when no proposals were given)."""
        return self.kwargs["custom_proposals"]

    @property
    def sampler_init_kwargs(self):
        """Keyword arguments passed to the ``PTSampler`` constructor.

        When ``outDir`` is None it is replaced by a temporary directory name
        built from ``self.outdir`` and ``self.label``.
        """
        keys = [
            "groups",
            "loglargs",
            "logp_grad",
            "logpkwargs",
            "loglkwargs",
            "logl_grad",
            "logpargs",
            "outDir",
            "verbose",
        ]
        init_kwargs = {key: self.kwargs[key] for key in keys}
        if init_kwargs["outDir"] is None:
            init_kwargs["outDir"] = f"{self.outdir}/ptmcmc_temp_{self.label}/"
        return init_kwargs

    @property
    def sampler_function_kwargs(self):
        """Keyword arguments passed to the ``sample`` method of the sampler."""
        keys = [
            "Niter",
            "neff",
            "Tmin",
            "HMCweight",
            "covUpdate",
            "SCAMweight",
            "ladder",
            "burn",
            "NUTSweight",
            "AMweight",
            "MALAweight",
            "thin",
            "HMCstepsize",
            "isave",
            "Tskip",
            "HMCsteps",
            "Tmax",
            "DEweight",
        ]
        return {key: self.kwargs[key] for key in keys}

    @staticmethod
    def _import_external_sampler():
        """Import the external package lazily and return its inner module."""
        from PTMCMCSampler import PTMCMCSampler

        return PTMCMCSampler

    @signal_wrapper
    def run_sampler(self):
        """Run the external sampler and store its output on ``self.result``.

        Returns
        =======
        The result object (``self.result``) with samples, burn-in length and
        log-likelihood evaluations filled in; the fields this sampler does not
        provide are set to NaN.
        """
        # Use a distinct local name so the external module does not shadow
        # this class inside the method.
        ptmcmc_module = self._import_external_sampler()
        sampler = ptmcmc_module.PTSampler(
            ndim=self.ndim,
            logp=self.log_prior,
            logl=self.log_likelihood,
            cov=np.eye(self.ndim),
            **self.sampler_init_kwargs,
        )

        custom_proposals = self.custom_proposals
        if custom_proposals is not None:
            for proposal in custom_proposals:
                # Each entry is indexed as [jump function, weight].
                jump_function = custom_proposals[proposal][0]
                weight = custom_proposals[proposal][1]
                logger.info(
                    f"Adding {proposal} to proposals with weight {weight}"
                )
                sampler.addProposalToCycle(jump_function, weight)

        sampler.sample(p0=self.p0, **self.sampler_function_kwargs)
        samples, meta, loglike = self.__read_in_data()

        self.calc_likelihood_count()
        self.result.nburn = self.sampler_function_kwargs["burn"]
        self.result.samples = samples[self.result.nburn :]
        self.meta_data["sampler_meta"] = meta
        self.result.log_likelihood_evaluations = loglike[self.result.nburn :]
        self.result.sampler_output = np.nan
        self.result.walkers = np.nan
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan
        return self.result

    def __read_in_data(self):
        """Read the data stored by PTMCMC to disk

        The chain file is read from the sampler output directory
        (``chain_1.txt``, falling back to ``chain_1.0.txt``). Its last four
        columns are stored as log-posterior, log-likelihood, total acceptance
        and swap acceptance (in that order); all preceding columns are the
        samples. Every ``*jump.txt`` file is loaded into the meta data keyed
        by the file name with ``_jump.txt`` removed. The output directory is
        deleted once everything has been read.

        Returns
        =======
        samples, meta, loglike: (array, dict, array)
        """
        # Function-scope import: the module-level imports do not provide `os`.
        import os

        temp_outDir = self.sampler_init_kwargs["outDir"]

        # Joining path components (rather than concatenating strings) also
        # works when a user-supplied outDir has no trailing separator.
        # ndmin=2 keeps the column slicing below valid for a one-row chain.
        try:
            data = np.loadtxt(os.path.join(temp_outDir, "chain_1.txt"), ndmin=2)
        except OSError:
            data = np.loadtxt(os.path.join(temp_outDir, "chain_1.0.txt"), ndmin=2)

        samples = data[:, :-4]
        loglike = data[:, -3]

        jump_accept = {}
        for jumpfile in glob.glob(os.path.join(temp_outDir, "*jump.txt")):
            label = os.path.basename(jumpfile).split("_jump.txt")[0]
            jump_accept[label] = np.loadtxt(jumpfile)

        meta = {
            "tot_accept": {"tot_accept": data[:, -2]},
            "PT_swap": {"swap_accept": data[:, -1]},
            "proposals": jump_accept,
            "log_post": {"log_post": data[:, -4]},
        }

        shutil.rmtree(temp_outDir)

        return samples, meta, loglike

    def write_current_state(self):
        """TODO: implement a checkpointing method"""
        pass