Coverage for bilby/core/sampler/zeus.py: 37%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import os 

2import shutil 

3from shutil import copyfile 

4 

5import numpy as np 

6 

7from .base_sampler import SamplerError, signal_wrapper 

8from .emcee import Emcee 

9from .ptemcee import LikePriorEvaluator 

10 

# Shared module-level evaluator instance: its `call_emcee` method is what
# `Zeus.sampler_init_kwargs` hands to the sampler as `logprob_fn`.
11_evaluator = LikePriorEvaluator() 

12 

13 

class Zeus(Emcee):
    """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `zeus.EnsembleSampler`, see
    documentation for that class for further help. Under Other Parameters, we
    list commonly used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    nsteps: int, (100)
        The number of steps
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `nsteps` and
        therefore the value must be less than `nsteps`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float, (3)
        The number of autocorrelation times to discard as burn-in

    """

    sampler_name = "zeus"
    default_kwargs = dict(
        nwalkers=500,
        args=[],
        kwargs={},
        pool=None,
        log_prob0=None,
        start=None,
        blobs0=None,
        iterations=100,
        thin=1,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        **kwargs,
    ):
        """Forward every argument unchanged to the `Emcee` initialiser."""
        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            pos0=pos0,
            nburn=nburn,
            burn_in_fraction=burn_in_fraction,
            resume=resume,
            burn_in_act=burn_in_act,
            **kwargs,
        )

    def _translate_kwargs(self, kwargs):
        """Rename emcee-style keys in ``kwargs`` (in place) to zeus names.

        The parent translation runs first. An emcee-style key is only moved
        when its zeus-style counterpart is not already present.
        """
        super()._translate_kwargs(kwargs=kwargs)

        # check if using emcee-style arguments
        # NOTE(review): in emcee `rstate0` appears to denote the random state
        # rather than a start position -- confirm this mapping is intended.
        # `run_sampler` overwrites `start` with `self.pos0` in any case.
        if "start" not in kwargs and "rstate0" in kwargs:
            kwargs["start"] = kwargs.pop("rstate0")
        if "log_prob0" not in kwargs and "lnprob0" in kwargs:
            kwargs["log_prob0"] = kwargs.pop("lnprob0")

    @property
    def sampler_function_kwargs(self):
        """dict: entries of ``self.kwargs`` that `run_sampler` passes to the
        sampler's ``sample`` call (only keys actually present are included).
        """
        keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"]

        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}

        return function_kwargs

    @property
    def sampler_init_kwargs(self):
        """dict: keyword arguments used to construct the `EnsembleSampler`.

        These are all entries of ``self.kwargs`` that are not sampling-call
        arguments, plus the log-probability function and the dimensionality.
        """
        # Evaluate the property once instead of once per key.
        function_kwargs = self.sampler_function_kwargs
        init_kwargs = {
            key: value
            for key, value in self.kwargs.items()
            if key not in function_kwargs
        }

        init_kwargs["logprob_fn"] = _evaluator.call_emcee
        init_kwargs["ndim"] = self.ndim

        return init_kwargs

    def write_current_state(self):
        """Write the sampler state via the parent implementation.

        The sampler's ``distribute`` attribute is set to the builtin `map`
        for the duration of the parent call and afterwards reset to the
        pool's ``map`` (or the builtin `map` if the pool has none).
        Presumably this keeps a reference to the pool out of the stored
        state -- TODO confirm against `Emcee.write_current_state`.
        """
        self._sampler.distribute = map
        super().write_current_state()
        self._sampler.distribute = getattr(self._sampler.pool, "map", map)

    def _initialise_sampler(self):
        """Create the `zeus.EnsembleSampler` and initialise the chain file."""
        # Imported lazily so that zeus is only required when it is used.
        from zeus import EnsembleSampler

        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
        self._init_chain_file()

    def write_chains_to_file(self, sample):
        """Append one sampler step to the checkpoint chain file.

        Rows are appended to a temporary copy of the chain file, which is
        then moved over the original, so the chain file itself is never
        opened for writing.

        Parameters
        ==========
        sample: sequence
            One item yielded by the sampler's ``sample`` generator. Elements
            0 and 2 are stacked column-wise to form the rows written out
            (presumably walker positions and blobs -- TODO confirm against
            the zeus documentation).
        """
        chain_file = self.checkpoint_info.chain_file
        temp_chain_file = chain_file + ".temp"
        if os.path.isfile(chain_file):
            copyfile(chain_file, temp_chain_file)

        points = np.hstack([sample[0], np.array(sample[2])])
        template = self.checkpoint_info.chain_template

        with open(temp_chain_file, "a") as ff:
            # One row per entry of `points`, prefixed with its index.
            ff.writelines(
                template.format(ii, *point) for ii, point in enumerate(points)
            )
        shutil.move(temp_chain_file, chain_file)

    def _set_pos0_for_resume(self):
        """Set ``pos0`` to the sampler's last sample when resuming."""
        self.pos0 = self.sampler.get_last_sample()

    @signal_wrapper
    def run_sampler(self):
        """Run the sampler, checkpoint the chain and fill in the result.

        Returns
        =======
        The populated ``self.result`` object.

        Raises
        ======
        SamplerError
            Raised by `_generate_result` if ``nburn`` exceeds ``nsteps``.
        """
        self._setup_pool()
        sampler_function_kwargs = self.sampler_function_kwargs
        # Only run the iterations that remain after earlier ones.
        iterations = sampler_function_kwargs.pop("iterations")
        iterations -= self._previous_iterations

        sampler_function_kwargs["start"] = self.pos0

        # main iteration loop
        for sample in self.sampler.sample(
            iterations=iterations, **sampler_function_kwargs
        ):
            self.write_chains_to_file(sample)
        self._close_pool()
        self.write_current_state()

        self.result.sampler_output = np.nan
        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
        self.print_nburn_logging_info()

        self._generate_result()

        self.result.samples = self.sampler.get_chain(flat=True, discard=self.nburn)
        self.result.walkers = self.sampler.chain
        return self.result

    def _generate_result(self):
        """Populate burn-in, likelihood, prior and evidence fields of the result.

        Raises
        ======
        SamplerError
            If the number of burn-in steps exceeds the number of steps.
        """
        self.result.nburn = self.nburn
        self.calc_likelihood_count()
        if self.result.nburn > self.nsteps:
            raise SamplerError(
                "The run has finished, but the chain is not burned in: "
                f"`nburn > nsteps` ({self.result.nburn} > {self.nsteps})."
                " Try increasing the number of steps."
            )
        # Blobs are reshaped to two columns: log-likelihood and log-prior.
        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape(
            (-1, 2)
        )
        log_likelihoods, log_priors = blobs.T
        self.result.log_likelihood_evaluations = log_likelihoods
        self.result.log_prior_evaluations = log_priors
        # No evidence is computed here, so both fields are set to NaN.
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan