Coverage for bilby/core/sampler/nestle.py: 53%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1from pandas import DataFrame 

2 

3from .base_sampler import NestedSampler, signal_wrapper 

4 

5 

class Nestle(NestedSampler):
    """bilby wrapper for `nestle.sample` (http://kylebarbary.com/nestle/)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `nestle.sample`, see documentation for
    that function for further help. Under Other Parameters, we list commonly
    used kwargs and the bilby defaults

    Other Parameters
    ================
    npoints: int
        The number of live points, note this can also equivalently be given as
        one of [nlive, nlives, n_live_points]
    method: {'classic', 'single', 'multi'} ('multi')
        Method used to select new points
    verbose: Bool
        If true, print information about the convergence during sampling

    """

    sampler_name = "nestle"
    default_kwargs = dict(
        verbose=True,
        method="multi",
        npoints=500,
        update_interval=None,
        npdim=None,
        maxiter=None,
        maxcall=None,
        dlogz=None,
        decline_factor=None,
        rstate=None,
        callback=None,
        steps=20,
        enlarge=1.2,
    )

    def _translate_kwargs(self, kwargs):
        """Rename equivalent keyword arguments to the names nestle expects.

        Any key listed in ``self.npoints_equiv_kwargs`` is moved to
        ``"npoints"`` and any key listed in ``self.walks_equiv_kwargs`` is
        moved to ``"steps"``, unless the target key was given explicitly.
        If several equivalent keys are present, all of them are removed and
        the last one found wins.

        Parameters
        ==========
        kwargs: dict
            The user supplied keyword arguments; modified in place.

        Returns
        =======
        dict: The translated keyword arguments (same object as the value
            returned by the parent implementation).

        """
        kwargs = super()._translate_kwargs(kwargs)
        if "npoints" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["npoints"] = kwargs.pop(equiv)
        if "steps" not in kwargs:
            for equiv in self.walks_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["steps"] = kwargs.pop(equiv)
        # Return the dict so that this method honours the same contract as
        # the parent implementation whose return value is used above.
        return kwargs

    def _verify_kwargs_against_default_kwargs(self):
        """Convert the ``verbose`` flag into a nestle callback, then verify.

        ``verbose`` is removed from ``self.kwargs`` before the parent class
        check runs; when it is truthy, ``callback`` is replaced by
        ``nestle.print_progress``.
        """
        # ``.get``/``.pop`` with defaults keep this method safe to call more
        # than once: after the first call "verbose" is no longer present.
        if self.kwargs.get("verbose", False):
            import nestle

            self.kwargs["callback"] = nestle.print_progress
        self.kwargs.pop("verbose", None)
        NestedSampler._verify_kwargs_against_default_kwargs(self)

    @signal_wrapper
    def run_sampler(self):
        """Runs Nestle sampler with given kwargs and returns the result

        Returns
        =======
        bilby.core.result.Result: Packaged information about the result

        """
        import nestle

        if nestle.__version__ == "0.2.0":
            # This is a very ugly hack to support numpy>=1.24
            nestle.np.float = float
            nestle.np.int = int

        out = nestle.sample(
            loglikelihood=self.log_likelihood,
            prior_transform=self.prior_transform,
            ndim=self.ndim,
            **self.kwargs
        )
        print("")

        self.result.sampler_output = out
        # Equal-weight posterior samples drawn from the weighted nested samples
        self.result.samples = nestle.resample_equal(out.samples, out.weights)
        self.result.nested_samples = DataFrame(
            out.samples, columns=self.search_parameter_keys
        )
        self.result.nested_samples["weights"] = out.weights
        self.result.nested_samples["log_likelihood"] = out.logl
        self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
            unsorted_loglikelihoods=out.logl,
            unsorted_samples=out.samples,
            sorted_samples=self.result.samples,
        )
        self.result.log_evidence = out.logz
        self.result.log_evidence_err = out.logzerr
        self.result.information_gain = out.h
        self.calc_likelihood_count()
        return self.result

    def _run_test(self):
        """
        Runs to test whether the sampler is properly running with the given
        kwargs without actually running to the end

        Note that this overwrites ``self.kwargs["maxiter"]`` with 2.

        Returns
        =======
        bilby.core.result.Result: Dummy container for sampling results.

        """
        self.kwargs["maxiter"] = 2
        return self.run_sampler()

    def write_current_state(self):
        """
        Nestle doesn't support checkpointing so no current state will be
        written on interrupt.
        """
        pass