Coverage for bilby/bilby_mcmc/proposals.py: 74%

635 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import importlib 

2import time 

3from abc import ABCMeta, abstractmethod 

4 

5import numpy as np 

6from scipy.spatial.distance import jensenshannon 

7from scipy.stats import gaussian_kde 

8 

9from ..core.fisher import FisherMatrixPosteriorEstimator 

10from ..core.prior import PriorDict 

11from ..core.sampler.base_sampler import SamplerError 

12from ..core.utils import logger, random, reflect 

13from ..gw.source import PARAMETER_SETS 

14 

15 

class ProposalCycle(object):
    """A pre-drawn, weighted sequence of proposals that is walked cyclically

    On construction, proposals are drawn at random (weighted by each
    proposal's ``weight`` attribute) into ``weighted_proposal_list``.
    ``get_proposal`` then steps through that list, wrapping around at the end.

    Parameters
    ----------
    proposal_list: list
        The proposal objects to cycle through; each must expose a ``weight``
        attribute.
    """

    def __init__(self, proposal_list):
        self.proposal_list = proposal_list
        self.weights = [prop.weight for prop in self.proposal_list]
        # Sum once up front rather than once per element of the comprehension
        total_weight = sum(self.weights)
        self.normalized_weights = [w / total_weight for w in self.weights]
        # The cycle length grows as the smallest normalised weight shrinks
        cycle_length = 10 * int(1 / min(self.normalized_weights))
        self.weighted_proposal_list = [
            random.rng.choice(self.proposal_list, p=self.normalized_weights)
            for _ in range(cycle_length)
        ]
        self.nproposals = len(self.weighted_proposal_list)
        self._position = 0

    @property
    def position(self):
        """Index of the next proposal to be returned by ``get_proposal``"""
        return self._position

    @position.setter
    def position(self, position):
        # Wrap around so the cycle repeats indefinitely
        self._position = np.mod(position, self.nproposals)

    def get_proposal(self):
        """Return the proposal at the current position and advance by one"""
        prop = self.weighted_proposal_list[self._position]
        self.position += 1
        return prop

    def __str__(self):
        lines = "".join(f" {prop}\n" for prop in self.proposal_list)
        return "ProposalCycle:\n" + lines

46 

47 

class BaseProposal(object):
    """Shared bookkeeping and boundary handling for all proposals

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    """

    # Class-level defaults for the acceptance counters
    _accepted = 0
    _rejected = 0
    # NOTE: Python 3 ignores a ``__metaclass__`` class attribute, so this does
    # not make the class abstract; it is kept as-is to preserve behaviour.
    __metaclass__ = ABCMeta

    def __init__(self, priors, weight=1, subset=None):
        self.weight = weight
        self.subset = subset
        # Attributes reported by __str__
        self._str_attrs = ["acceptance_ratio", "n"]

        if subset is None:
            self.parameters = priors.non_fixed_keys
        else:
            # Restrict to the requested subset, keeping the prior ordering
            self.parameters = [p for p in priors.non_fixed_keys if p in subset]
            self._str_attrs.append("parameters")

        if len(self.parameters) == 0:
            raise ValueError("Proposal requested with zero parameters")

        self.ndim = len(self.parameters)

        # Cache the per-key boundary information for every prior
        self.prior_boundary_dict = {}
        self.prior_minimum_dict = {}
        self.prior_maximum_dict = {}
        self.prior_width_dict = {}
        for key in priors:
            prior = priors[key]
            self.prior_boundary_dict[key] = prior.boundary
            self.prior_minimum_dict[key] = np.max(prior.minimum)
            self.prior_maximum_dict[key] = np.min(prior.maximum)
            self.prior_width_dict[key] = np.max(prior.width)

    @property
    def accepted(self):
        """Number of accepted proposals"""
        return self._accepted

    @accepted.setter
    def accepted(self, accepted):
        self._accepted = accepted

    @property
    def rejected(self):
        """Number of rejected proposals"""
        return self._rejected

    @rejected.setter
    def rejected(self, rejected):
        self._rejected = rejected

    @property
    def acceptance_ratio(self):
        """Fraction of proposals accepted, or NaN before any were counted"""
        if self.n == 0:
            return np.nan
        return self.accepted / self.n

    @property
    def n(self):
        """Total number of proposals counted (accepted plus rejected)"""
        return self.accepted + self.rejected

    def __str__(self):
        pieces = []
        for attr in self._str_attrs:
            val = getattr(self, attr, "N/A")
            if isinstance(val, (float, int)):
                val = f"{val:1.2g}"
            pieces.append(f"{attr}:{val},")
        return f"{type(self).__name__}(" + "".join(pieces) + ")"

    def apply_boundaries(self, point):
        """Map the proposal's parameters in ``point`` back inside the prior

        Parameters without a boundary are left untouched; ``"periodic"`` and
        ``"reflective"`` boundaries are applied in place.

        Raises
        ------
        SamplerError
            If a prior has any other boundary setting.
        """
        for key in self.parameters:
            boundary = self.prior_boundary_dict[key]
            if boundary is None:
                continue
            if boundary == "periodic":
                wrapped = self.apply_periodic_boundary(key, point[key])
            elif boundary == "reflective":
                wrapped = self.apply_reflective_boundary(key, point[key])
            else:
                raise SamplerError(f"Boundary {boundary} not implemented")
            point[key] = wrapped
        return point

    def apply_periodic_boundary(self, key, val):
        """Wrap ``val`` into the prior range of ``key`` modulo its width"""
        lower = self.prior_minimum_dict[key]
        return lower + np.mod(val - lower, self.prior_width_dict[key])

    def apply_reflective_boundary(self, key, val):
        """Reflect ``val`` into the prior range of ``key``"""
        lower = self.prior_minimum_dict[key]
        span = self.prior_width_dict[key]
        # ``reflect`` is applied to the value rescaled by the prior range
        unit_val = (val - lower) / span
        return lower + span * reflect(np.array(unit_val))

    def __call__(self, chain, likelihood=None, priors=None):
        # Proposals flagged with ``needs_likelihood_and_priors`` receive the
        # extra arguments; all others are called with the chain only.
        if getattr(self, "needs_likelihood_and_priors", False):
            args = (chain, likelihood, priors)
        else:
            args = (chain,)
        sample, log_factor = self.propose(*args)

        # Boundaries are only applied when the proposal reports log_factor 0
        if log_factor == 0:
            sample = self.apply_boundaries(sample)

        return sample, log_factor

    @abstractmethod
    def propose(self, chain):
        """Propose a new point

        Implemented proposals must override this method. It is invoked by
        __call__, which then applies the boundaries before handing back the
        proposed point.

        Parameters
        ----------
        chain:
            The chain to use for the proposal

        Returns
        -------
        proposal:
            The proposed point
        log_factor: float
            The natural-log of the additional factor entering the acceptance
            probability to ensure detailed balance. For symmetric proposals,
            a value of 0 should be returned.
        """
        pass

    @staticmethod
    def check_dependencies(warn=True):
        """Check the dependencies required to use the proposal

        Parameters
        ----------
        warn: bool
            If true, print a warning

        Returns
        -------
        check: bool
            If true, dependencies exist
        """
        return True

187 

188 

class FixedGaussianProposal(BaseProposal):
    """A proposal using a fixed non-correlated Gaussian distribution

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    sigma: float, int or dict
        The scaling factor for proposals. A number is applied to every
        parameter; a dict is looked up per parameter key.

    Raises
    ------
    SamplerError
        If ``sigma`` is neither a number nor a dict.
    """

    def __init__(self, priors, weight=1, subset=None, sigma=0.01):
        super(FixedGaussianProposal, self).__init__(priors, weight, subset)
        self.sigmas = {}
        for key in self.parameters:
            # An infinite prior width cannot scale the jump; use unit width
            if np.isinf(self.prior_width_dict[key]):
                self.prior_width_dict[key] = 1
            # Accept ints as well as floats, matching AdaptiveGaussianProposal
            if isinstance(sigma, (float, int)):
                self.sigmas[key] = sigma
            elif isinstance(sigma, dict):
                self.sigmas[key] = sigma[key]
            else:
                raise SamplerError("FixedGaussianProposal sigma not understood")

    def propose(self, chain):
        """Perturb each parameter by an independent Gaussian step

        The per-parameter standard deviation is the prior width multiplied by
        the configured sigma. Returns the sample and a log-factor of 0.
        """
        sample = chain.current_sample
        for key in self.parameters:
            sigma = self.prior_width_dict[key] * self.sigmas[key]
            sample[key] += sigma * random.rng.normal(0, 1)
        log_factor = 0
        return sample, log_factor

225 

226 

class AdaptiveGaussianProposal(BaseProposal):
    """A non-correlated Gaussian proposal with an adaptive overall scale

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    sigma: float, int or dict
        Per-parameter scaling factor; a number applies to every parameter
    scale_init: float
        Initial value of the adaptive scale
    stop: float
        Adaptation is only performed while the proposal count is below this
    target_facc: float
        The acceptance fraction used in the scale update
    """

    def __init__(
        self,
        priors,
        weight=1,
        subset=None,
        sigma=1,
        scale_init=1e0,
        stop=1e5,
        target_facc=0.234,
    ):
        super().__init__(priors, weight, subset)
        self.sigmas = {}
        for key in self.parameters:
            # Replace an infinite prior width by unity
            if np.isinf(self.prior_width_dict[key]):
                self.prior_width_dict[key] = 1
            if isinstance(sigma, (float, int)):
                self.sigmas[key] = sigma
            elif isinstance(sigma, dict):
                self.sigmas[key] = sigma[key]
            else:
                raise SamplerError("AdaptiveGaussianProposal sigma not understood")

        self.target_facc = target_facc
        self.scale = scale_init
        self.stop = stop
        self._str_attrs.append("scale")
        self._last_accepted = 0

    def propose(self, chain):
        """Update the scale, then take an independent Gaussian step per key"""
        sample = chain.current_sample
        self.update_scale(chain)

        # Occasionally inflate the jump; the second draw only happens when
        # the first one did not trigger the x10 boost.
        if random.rng.uniform(0, 1) < 1e-3:
            boost = 1e1
        elif random.rng.uniform(0, 1) < 1e-4:
            boost = 1e2
        else:
            boost = 1

        for key in self.parameters:
            step = boost * self.scale * self.prior_width_dict[key] * self.sigmas[key]
            sample[key] += step * random.rng.normal(0, 1)
        return sample, 0

    def update_scale(self, chain):
        """
        The adaptation of the scale follows (35)/(36) of https://arxiv.org/abs/1409.7215
        """
        if not (0 < self.n < self.stop):
            return

        s_gamma = (self.stop / self.n) ** 0.2 - 1
        # Grow the scale if something was accepted since the last update,
        # otherwise shrink it
        if self.accepted > self._last_accepted:
            self.scale += s_gamma * (1 - self.target_facc) / 100
        else:
            self.scale -= s_gamma * self.target_facc / 100
        self._last_accepted = self.accepted
        # Keep the scale bounded from below
        self.scale = max(self.scale, 1 / self.stop)

283 

284 

class DifferentialEvolutionProposal(BaseProposal):
    """A proposal using Differential Evolution

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    mode_hopping_frac: float
        Threshold on a uniform draw: draws above it use a unit jump factor,
        the rest use a randomly scaled jump factor.
        NOTE(review): as coded, the unit-factor branch is taken with
        probability ``1 - mode_hopping_frac`` - confirm the intended sense.
    """

    def __init__(self, priors, weight=1, subset=None, mode_hopping_frac=0.5):
        super().__init__(priors, weight, subset)
        self.mode_hopping_frac = mode_hopping_frac

    def propose(self, chain):
        """Jump along the difference of two random samples from the chain"""
        theta = chain.current_sample
        first = chain.random_sample
        second = chain.random_sample

        if random.rng.uniform(0, 1) > self.mode_hopping_frac:
            gamma = 1
        else:
            # Base jump size
            base_gamma = random.rng.normal(0, 2.38 / np.sqrt(2 * self.ndim))
            # Scale uniformly in log between 0.1 and 10 times
            log_stretch = np.log(0.1) + np.log(100.0) * random.rng.uniform(0, 1)
            gamma = base_gamma * np.exp(log_stretch)

        for key in self.parameters:
            theta[key] += gamma * (second[key] - first[key])

        return theta, 0

322 

323 

class UniformProposal(BaseProposal):
    """A proposal using uniform draws from the prior support

    Note: for priors with infinite support, this proposal will not propose a
    point, leading to inefficient sampling. You may wish to omit this proposal
    if you have priors with infinite support.

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    """

    def __init__(self, priors, weight=1, subset=None):
        super(UniformProposal, self).__init__(priors, weight, subset)

    def propose(self, chain):
        """Redraw each finite-width parameter uniformly over its prior range

        Parameters whose prior width is infinite are left at their current
        value. Returns the sample and a log-factor of 0.
        """
        sample = chain.current_sample
        for key in self.parameters:
            width = self.prior_width_dict[key]
            # np.isinf returns a NumPy bool, so an identity test against the
            # Python ``False`` singleton never succeeds; test truthiness.
            if not np.isinf(width):
                sample[key] = random.rng.uniform(
                    self.prior_minimum_dict[key], self.prior_maximum_dict[key]
                )
            # else: unable to generate a uniform sample on infinite support
        log_factor = 0
        return sample, log_factor

358 

359 

class PriorProposal(BaseProposal):
    """A proposal using draws from the prior distribution

    Note: for priors which use interpolation, this proposal can be problematic
    as the proposal gets pickled in multiprocessing. Either, use serial
    processing (npool=1) or fall back to a UniformProposal.

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    """

    def __init__(self, priors, weight=1, subset=None):
        super().__init__(priors, weight, subset)
        # Keep only the priors of the parameters this proposal acts on
        selected = {key: priors[key] for key in self.parameters}
        self.priors = PriorDict(selected)

    def propose(self, chain):
        """Replace the proposal's parameters by a fresh draw from the prior

        The returned log-factor is the prior log-probability of the current
        point minus that of the proposed point.
        """
        sample = chain.current_sample
        ln_prior_current = self.priors.ln_prob(sample.as_dict(self.parameters))

        drawn = self.priors.sample()
        for key in self.parameters:
            sample[key] = drawn[key]

        ln_prior_proposed = self.priors.ln_prob(sample.as_dict(self.parameters))
        return sample, ln_prior_current - ln_prior_proposed

391 

392 

393_density_estimate_doc = """ A proposal using draws from a {estimator} fit to the chain 

394 

395Parameters 

396---------- 

397priors: bilby.core.prior.PriorDict 

398 The set of priors 

399weight: float 

400 Weighting factor 

401subset: list 

402 A list of keys for which to restrict the proposal to (other parameters 

403 will be kept fixed) 

404first_fit: int 

405 The number of steps to take before first fitting the KDE 

406fit_multiplier: int 

407 The multiplier for the next fit 

408nsamples_for_density: int 

409 The number of samples to use when fitting the KDE 

410fallback: bilby.core.sampler.bilby_mcmc.proposal.BaseProposal 

411 A proposal to use before first training 

412scale_fits: int 

413 A scaling factor for both the initial and subsequent updates 

414""" 

415 

416 

class DensityEstimateProposal(BaseProposal):
    # Name of the estimator used in log messages; subclasses override it
    density_name = None
    __doc__ = _density_estimate_doc.format(estimator=density_name)

    def __init__(
        self,
        priors,
        weight=1,
        subset=None,
        first_fit=1000,
        fit_multiplier=10,
        nsamples_for_density=1000,
        fallback=AdaptiveGaussianProposal,
        scale_fits=1,
    ):
        super().__init__(priors, weight, subset)
        self.nsamples_for_density = nsamples_for_density
        # ``fallback`` is a proposal class; build it with the same setup
        self.fallback = fallback(priors, weight, subset)
        self.fit_multiplier = fit_multiplier * scale_fits

        # Counters
        self.steps_since_refit = 0
        self.next_refit_time = first_fit * scale_fits
        self.density = None
        self.trained = False
        self._str_attrs.append("trained")

    # Hooks to be provided by concrete estimators -------------------------

    def _fit(self, dataset):
        raise NotImplementedError

    def _evaluate(self, point):
        raise NotImplementedError

    def _sample(self, nsamples=None):
        raise NotImplementedError

    def refit(self, chain):
        """Fit a new density to random chain samples and verify it

        The new density is discarded in favour of the previous one when, for
        any parameter, the spread of its draws is below a tenth of the spread
        of an independently drawn verification set.
        """
        previous_density = self.density
        t_start = time.time()

        # Draw two (possibly overlapping) data sets for training and
        # verification, alternating between them sample by sample
        ndraws = min(chain.position, self.nsamples_for_density)
        training_rows = []
        verification_rows = []
        for _ in range(ndraws):
            drawn = chain.random_sample
            training_rows.append([drawn[key] for key in self.parameters])
            drawn = chain.random_sample
            verification_rows.append([drawn[key] for key in self.parameters])

        # Fit the density
        self.density = self._fit(np.array(training_rows).T)

        # Print a log message
        took = time.time() - t_start
        logger.debug(
            f"{self.density_name} construction at {self.steps_since_refit} finished"
            f" for length {chain.position} chain, took {took:0.2f}s."
            f" Current accept-ratio={self.acceptance_ratio:0.2f}"
        )

        # Reset counters for next training
        self.steps_since_refit = 0
        self.next_refit_time *= self.fit_multiplier

        # Verify training hasn't overconstrained
        new_draws = np.atleast_2d(self._sample(1000))
        verification = np.array(verification_rows)
        fail_parameters = [
            key
            for ii, key in enumerate(self.parameters)
            if np.std(new_draws[:, ii]) < 0.1 * np.std(verification[:, ii])
        ]

        if len(fail_parameters) > 0:
            logger.debug(
                f"{self.density_name} construction failed verification and is discarded"
            )
            self.density = previous_density
        else:
            self.trained = True

    def propose(self, chain):
        """Draw from the fitted density, or from the fallback before training

        Returns the proposed sample together with the log-ratio of the
        density at the current point to the density at the proposed point.
        """
        self.steps_since_refit += 1

        # Check if we refit
        if self.steps_since_refit >= self.next_refit_time:
            try:
                self.refit(chain)
            except Exception as e:
                logger.warning(f"Failed to refit chain due to error {e}")

        # If the density is yet to be fitted, use the fallback
        if self.trained is False:
            return self.fallback.propose(chain)

        def ln_density(point):
            values = list(point.as_dict(self.parameters).values())
            return self._evaluate(values)

        # Grab the current sample and its probability under the density
        theta = chain.current_sample
        ln_p_theta = ln_density(theta)

        # Sample and update theta
        drawn = self._sample(1)
        for key, val in zip(self.parameters, drawn):
            theta[key] = val

        # Probability of the new sample under the density
        ln_p_thetaprime = ln_density(theta)

        # Q(theta|theta') / Q(theta'|theta), in log form
        return theta, ln_p_theta - ln_p_thetaprime

531 

532 

class KDEProposal(DensityEstimateProposal):
    density_name = "Gaussian KDE"
    __doc__ = _density_estimate_doc.format(estimator=density_name)

    def _fit(self, dataset):
        # Build a Gaussian KDE from the dataset handed over by ``refit``
        kde = gaussian_kde(dataset)
        return kde

    def _evaluate(self, point):
        # Log-density at ``point``; first element of the returned array
        log_density = self.density.logpdf(point)
        return log_density[0]

    def _sample(self, nsamples=None):
        # Drop length-one axes but always hand back at least a 1-d array
        draws = self.density.resample(nsamples)
        return np.atleast_1d(np.squeeze(draws))

545 

546 

class GMMProposal(DensityEstimateProposal):
    density_name = "Gaussian Mixture Model"
    __doc__ = _density_estimate_doc.format(estimator=density_name)

    def _fit(self, dataset):
        # Imported lazily so that sklearn remains an optional dependency
        from sklearn.mixture import GaussianMixture

        density = GaussianMixture(n_components=10)
        # ``refit`` passes the transposed dataset; transpose it back for sklearn
        density.fit(dataset.T)
        return density

    def _evaluate(self, point):
        return np.squeeze(self.density.score_samples(np.atleast_2d(point)))

    def _sample(self, nsamples=None):
        # ``sample`` returns a tuple; only its first element is used
        return np.squeeze(self.density.sample(n_samples=nsamples)[0])

    @staticmethod
    def check_dependencies(warn=True):
        """Check that sklearn is available

        Parameters
        ----------
        warn: bool
            If true, log a warning when sklearn is missing

        Returns
        -------
        check: bool
            If true, dependencies exist
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly.
        import importlib.util

        if importlib.util.find_spec("sklearn") is None:
            if warn:
                logger.warning(
                    "Unable to utilise GMMProposal as sklearn is not installed"
                )
            return False
        else:
            return True

574 

575 

class NormalizingFlowProposal(DensityEstimateProposal):
    density_name = "Normalizing Flow"
    __doc__ = _density_estimate_doc.format(estimator=density_name) + (
        "js_factor: float\n"
        "    The factor to use in determining the max-JS factor to terminate\n"
        "    training.\n"
        "max_training_epochs: int\n"
        "    The maximum number of training steps to take\n"
    )

    def __init__(
        self,
        priors,
        weight=1,
        subset=None,
        first_fit=1000,
        fit_multiplier=10,
        max_training_epochs=1000,
        scale_fits=1,
        nsamples_for_density=1000,
        js_factor=10,
        fallback=AdaptiveGaussianProposal,
    ):
        super(NormalizingFlowProposal, self).__init__(
            priors=priors,
            weight=weight,
            subset=subset,
            first_fit=first_fit,
            fit_multiplier=fit_multiplier,
            nsamples_for_density=nsamples_for_density,
            fallback=fallback,
            scale_fits=scale_fits,
        )
        # The flow and optimizer are built on the first call to ``propose``
        self.initialised = False
        self.max_training_epochs = max_training_epochs
        self.js_factor = js_factor

    def initialise(self):
        """Build the flow and its optimizer"""
        self.setup_flow()
        self.setup_optimizer()
        self.initialised = True

    def setup_flow(self):
        """Choose the flow type based on the number of dimensions"""
        if self.ndim < 3:
            self.setup_basic_flow()
        else:
            self.setup_NVP_flow()

    def setup_NVP_flow(self):
        from .flows import NVPFlow

        self.flow = NVPFlow(
            features=self.ndim,
            hidden_features=self.ndim * 2,
            num_layers=2,
            num_blocks_per_layer=2,
            batch_norm_between_layers=True,
            batch_norm_within_layers=True,
        )

    def setup_basic_flow(self):
        from .flows import BasicFlow

        self.flow = BasicFlow(features=self.ndim)

    def setup_optimizer(self):
        from torch import optim

        self.optimizer = optim.Adam(self.flow.parameters())

    def get_training_data(self, chain):
        """Return a list of random chain samples restricted to the parameters

        At most ``nsamples_for_density`` samples are drawn, and no more than
        the current chain position.
        """
        training_data = []
        nsamples_for_density = min(chain.position, self.nsamples_for_density)
        for _ in range(nsamples_for_density):
            s = chain.random_sample
            training_data.append([s[key] for key in self.parameters])
        return training_data

    def _calculate_js(self, validation_samples, training_samples_draw):
        """Return the squared maximum per-dimension Jensen-Shannon distance

        For each dimension, Gaussian KDEs of the two sample sets are compared
        on a common 100-point grid.
        """
        max_js = 0
        for i in range(self.ndim):
            A = validation_samples[:, i]
            B = training_samples_draw[:, i]
            # The grid spans the union of both sample ranges; taking the
            # smaller of the two maxima would cut off the upper tail of one
            # distribution while the lower bound already covers both.
            xmin = np.min([np.min(A), np.min(B)])
            xmax = np.max([np.max(A), np.max(B)])
            xval = np.linspace(xmin, xmax, 100)
            Apdf = gaussian_kde(A)(xval)
            Bpdf = gaussian_kde(B)(xval)
            js = jensenshannon(Apdf, Bpdf)
            max_js = max(max_js, js)
        return np.power(max_js, 2)

    def train(self, chain):
        """Train the flow on random chain samples

        Training stops after ``max_training_epochs`` epochs, or earlier when
        the value from ``_calculate_js`` (checked every tenth epoch) falls
        below ``js_factor / nsamples_for_density``.
        """
        logger.debug("Starting NF training")

        import torch

        start = time.time()

        training_samples = np.array(self.get_training_data(chain))
        validation_samples = np.array(self.get_training_data(chain))

        training_tensor = torch.tensor(training_samples, dtype=torch.float32)

        max_js_threshold = self.js_factor / self.nsamples_for_density

        for epoch in range(1, self.max_training_epochs + 1):
            self.optimizer.zero_grad()
            loss = -self.flow.log_prob(inputs=training_tensor).mean()
            loss.backward()
            self.optimizer.step()

            # Draw from the current flow
            self.flow.eval()
            training_samples_draw = (
                self.flow.sample(self.nsamples_for_density).detach().numpy()
            )
            self.flow.train()

            if np.mod(epoch, 10) == 0:
                max_js_bits = self._calculate_js(
                    validation_samples, training_samples_draw
                )
                if max_js_bits < max_js_threshold:
                    logger.debug(
                        f"Training complete after {epoch} steps, "
                        f"max_js_bits={max_js_bits:0.5f}<{max_js_threshold}"
                    )
                    break

        took = time.time() - start
        logger.debug(
            f"Flow training step ({self.steps_since_refit}) finished"
            f" for length {chain.position} chain, took {took:0.2f}s."
            f" Current accept-ratio={self.acceptance_ratio:0.2f}"
        )
        self.steps_since_refit = 0
        self.next_refit_time *= self.fit_multiplier
        self.trained = True

    def propose(self, chain):
        """Draw a point from the flow, or from the fallback before training

        Returns the proposed sample and the log-ratio of the flow density at
        the current point to the flow density at the proposed point.
        """
        if self.initialised is False:
            self.initialise()

        import torch

        self.steps_since_refit += 1
        theta = chain.current_sample

        # Check if we retrain the NF
        testA = self.steps_since_refit >= self.next_refit_time
        if testA:
            try:
                self.train(chain)
            except Exception as e:
                logger.warning(f"Failed to retrain chain due to error {e}")

        if self.trained is False:
            return self.fallback.propose(chain)

        self.flow.eval()
        theta_prime_T = self.flow.sample(1)

        logp_theta_prime = self.flow.log_prob(theta_prime_T).detach().numpy()[0]
        theta_T = torch.tensor(
            np.atleast_2d([theta[key] for key in self.parameters]), dtype=torch.float32
        )
        logp_theta = self.flow.log_prob(theta_T).detach().numpy()[0]
        log_factor = logp_theta - logp_theta_prime

        flow_sample_values = np.atleast_1d(np.squeeze(theta_prime_T.detach().numpy()))
        for key, val in zip(self.parameters, flow_sample_values):
            theta[key] = val

        return theta, float(log_factor)

    @staticmethod
    def check_dependencies(warn=True):
        """Check that glasflow is available

        Parameters
        ----------
        warn: bool
            If true, log a warning when glasflow is missing

        Returns
        -------
        check: bool
            If true, dependencies exist
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly.
        import importlib.util

        if importlib.util.find_spec("glasflow") is None:
            if warn:
                logger.warning(
                    "Unable to utilise NormalizingFlowProposal as glasflow is not installed"
                )
            return False
        else:
            return True

765 

766 

class FixedJumpProposal(BaseProposal):
    """A proposal taking a fixed-size jump of random sign plus small noise

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    jumps: int, float or dict
        The jump size; a number applies to every parameter of the proposal,
        a dict is used as given (key -> jump size)
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    weight: float
        Weighting factor
    scale: float
        Standard deviation of the Gaussian noise, in units of the prior width
    """

    def __init__(self, priors, jumps=1, subset=None, weight=1, scale=1e-4):
        super().__init__(priors, weight, subset)
        self.scale = scale
        if isinstance(jumps, dict):
            self.jumps = jumps
        elif isinstance(jumps, (int, float)):
            self.jumps = dict.fromkeys(self.parameters, jumps)
        else:
            raise SamplerError("jumps not understood")

    def propose(self, chain):
        """Shift each key by +/- its jump, plus noise scaled by prior width"""
        sample = chain.current_sample
        for key, jump in self.jumps.items():
            # Random direction: -1 or +1
            direction = random.rng.integers(2) * 2 - 1
            # ``epsilon`` draws a fresh random number on every access
            sample[key] += direction * jump + self.epsilon * self.prior_width_dict[key]
        return sample, 0

    @property
    def epsilon(self):
        """A fresh zero-mean Gaussian draw multiplied by ``scale``"""
        return self.scale * random.rng.normal()

789 

790 

class FisherMatrixProposal(AdaptiveGaussianProposal):
    """Fisher Matrix Proposals

    Uses a finite differencing approach motivated by BayesWave (see, e.g.
    https://arxiv.org/abs/1410.3835). The inverse Fisher Information Matrix
    is calculated from the current sample, then proposals are drawn from a
    multivariate Gaussian and scaled by an adaptive parameter.

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    subset: list
        A list of keys for which to restrict the proposal to (other parameters
        will be kept fixed)
    weight: float
        Weighting factor
    update_interval: int
        Number of proposals between recalculations of the inverse Fisher
        Information Matrix
    scale_init: float
        Initial value of the scale multiplying each jump
    fd_eps: float
        Passed to ``FisherMatrixPosteriorEstimator`` as ``fd_eps``
    adapt: bool
        If true, the scale is adapted on every proposal
    """

    # The docstring must be the first statement of the class body to become
    # ``__doc__``; this flag makes ``__call__`` pass likelihood and priors.
    needs_likelihood_and_priors = True

    def __init__(
        self,
        priors,
        subset=None,
        weight=1,
        update_interval=100,
        scale_init=1e0,
        fd_eps=1e-4,
        adapt=False,
    ):
        super(FisherMatrixProposal, self).__init__(
            priors, weight, subset, scale_init=scale_init
        )
        self.update_interval = update_interval
        # Start at the interval so the matrix is computed on the first call
        self.steps_since_update = update_interval
        self.adapt = adapt
        self.mean = np.zeros(len(self.parameters))
        self.fd_eps = fd_eps

    def propose(self, chain, likelihood, priors):
        """Jump by a multivariate Gaussian draw with covariance ``iFIM``

        If the matrix calculation fails, the previously stored matrix is
        reused; when none exists yet the current sample is returned unchanged.
        """
        sample = chain.current_sample
        if self.adapt:
            self.update_scale(chain)
        if self.steps_since_update >= self.update_interval:
            fmp = FisherMatrixPosteriorEstimator(
                likelihood, priors, parameters=self.parameters, fd_eps=self.fd_eps
            )
            try:
                self.iFIM = fmp.calculate_iFIM(sample.dict)
            except (RuntimeError, ValueError, np.linalg.LinAlgError) as e:
                logger.warning(f"FisherMatrixProposal failed with {e}")
                if hasattr(self, "iFIM") is False:
                    # No past iFIM exists, return sample
                    return sample, 0
            self.steps_since_update = 0

        jump = self.scale * random.rng.multivariate_normal(
            self.mean, self.iFIM, check_valid="ignore"
        )

        for key, val in zip(self.parameters, jump):
            sample[key] += val

        log_factor = 0
        self.steps_since_update += 1
        return sample, log_factor

847 

848 

class BaseGravitationalWaveTransientProposal(BaseProposal):
    """Helpers shared by the gravitational-wave transient proposals

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    """

    def __init__(self, priors, weight=1):
        super().__init__(priors, weight=weight)
        # Prefer "phase" over "delta_phase"; None when neither has a prior
        self.phase_key = next(
            (name for name in ("phase", "delta_phase") if name in priors), None
        )

    def get_cos_theta_jn(self, sample):
        """Return cos(theta_jn) from whichever of the two keys is sampled"""
        keys = sample.parameter_keys
        if "cos_theta_jn" in keys:
            return sample["cos_theta_jn"]
        if "theta_jn" in keys:
            return np.cos(sample["theta_jn"])
        raise SamplerError()

    def get_phase(self, sample):
        """Return the phase, converting from delta_phase if needed"""
        if "phase" in sample.parameter_keys:
            return sample["phase"]
        if "delta_phase" not in sample.parameter_keys:
            raise SamplerError()

        cos_theta_jn = self.get_cos_theta_jn(sample)
        delta_phase = sample["delta_phase"]
        psi = sample["psi"]
        return np.mod(delta_phase - np.sign(cos_theta_jn) * psi, 2 * np.pi)

    def get_delta_phase(self, phase, sample):
        """Return ``phase + sign(cos_theta_jn) * psi`` for the given sample"""
        orientation = np.sign(self.get_cos_theta_jn(sample))
        return phase + orientation * sample["psi"]

887 

888 

class CorrelatedPolarisationPhaseJump(BaseGravitationalWaveTransientProposal):
    """Jump along one of the combinations psi + phase or psi - phase"""

    def __init__(self, priors, weight=1):
        super().__init__(priors, weight=weight)

    def propose(self, chain):
        """Redraw either the sum or the difference of psi and phase"""
        sample = chain.current_sample
        phase = self.get_phase(sample)

        psi_plus_phase = sample["psi"] + phase
        psi_minus_phase = sample["psi"] - phase

        # With equal probability redraw one combination, keeping the other
        if random.rng.random() < 0.5:
            psi_plus_phase = 3.0 * np.pi * random.rng.random()
        else:
            psi_minus_phase = 3.0 * np.pi * random.rng.random() - 2 * np.pi

        # Update: psi is written first, as get_delta_phase reads it back
        sample["psi"] = (psi_plus_phase + psi_minus_phase) * 0.5
        phase = (psi_plus_phase - psi_minus_phase) * 0.5

        if self.phase_key == "delta_phase":
            sample["delta_phase"] = self.get_delta_phase(phase, sample)
        else:
            sample["phase"] = phase

        return sample, 0

917 

918 

class PhaseReversalProposal(BaseGravitationalWaveTransientProposal):
    """Shift the phase parameter by pi, optionally with Gaussian fuzz

    Parameters
    ----------
    priors: bilby.core.prior.PriorDict
        The set of priors
    weight: float
        Weighting factor
    fuzz: bool
        If true, add a Gaussian perturbation to the shift
    fuzz_sigma: float
        Standard deviation of that perturbation
    """

    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1):
        super().__init__(priors, weight)
        self.fuzz = fuzz
        self.fuzz_sigma = fuzz_sigma
        if self.phase_key is None:
            raise SamplerError(
                f"{type(self).__name__} initialised without a phase prior"
            )

    def propose(self, chain):
        """Move the phase by pi (plus fuzz), wrapped into [0, 2 pi)"""
        sample = chain.current_sample
        shifted = sample[self.phase_key] + np.pi + self.epsilon
        sample[self.phase_key] = np.mod(shifted, 2 * np.pi)
        return sample, 0

    @property
    def epsilon(self):
        """A fresh Gaussian draw when fuzzing is enabled, otherwise 0"""
        if not self.fuzz:
            return 0
        return random.rng.normal(0, self.fuzz_sigma)

942 

943 

class PolarisationReversalProposal(PhaseReversalProposal):
    """Shift psi by pi / 2, optionally with Gaussian fuzz"""

    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-3):
        super().__init__(priors, weight, fuzz, fuzz_sigma)
        self.fuzz = fuzz

    def propose(self, chain):
        """Move psi by pi / 2 (plus fuzz), wrapped into [0, pi)"""
        sample = chain.current_sample
        shifted = sample["psi"] + np.pi / 2 + self.epsilon
        sample["psi"] = np.mod(shifted, np.pi)
        return sample, 0

957 

958 

class PhasePolarisationReversalProposal(PhaseReversalProposal):
    """Shift the phase by pi and psi by pi / 2 in a single proposal"""

    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1):
        super().__init__(priors, weight, fuzz, fuzz_sigma)
        self.fuzz = fuzz

    def propose(self, chain):
        """Move phase and psi; each shift gets its own ``epsilon`` draw"""
        sample = chain.current_sample

        new_phase = sample[self.phase_key] + np.pi + self.epsilon
        sample[self.phase_key] = np.mod(new_phase, 2 * np.pi)

        new_psi = sample["psi"] + np.pi / 2 + self.epsilon
        sample["psi"] = np.mod(new_psi, np.pi)

        return sample, 0

974 

975 

class StretchProposal(BaseProposal):
    """The Goodman & Weare (2010) Stretch proposal for an MCMC chain

    The stretch move is performed relative to a random sample drawn from the
    chain itself, with g(z) taken from Equation (9) of [1].

    References
    ----------
    [1] Goodman & Weare (2010)
        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract

    """

    def __init__(self, priors, weight=1, subset=None, scale=2):
        super().__init__(priors, weight, subset)
        self.scale = scale

    def propose(self, chain):
        """Stretch the current sample relative to a random chain sample"""
        current = chain.current_sample
        # Draw a random sample to act as the complement
        anchor = chain.random_sample
        return _stretch_move(current, anchor, self.scale, self.ndim, self.parameters)

1000 

1001 

def _stretch_move(sample, complement, scale, ndim, parameters):
    """Apply a stretch move to ``sample`` about ``complement``

    A single stretch factor ``z = (u * (scale - 1) + 1) ** 2 / scale`` is
    computed from one uniform draw ``u`` on [0, 1). Each key in ``parameters``
    is then replaced by ``complement + (sample - complement) * z``.

    Parameters
    ----------
    sample:
        Mapping of parameter values to move; it is modified in place.
    complement:
        Mapping of parameter values used as the anchor of the stretch.
    scale:
        Scale parameter entering the stretch-factor formula.
    ndim:
        Number of dimensions used in the log factor.
    parameters:
        Iterable of keys to update.

    Returns
    -------
    sample, log_factor:
        The updated ``sample`` and ``(ndim - 1) * log(z)``.
    """
    draw = random.rng.uniform(0, 1)
    stretch = (draw * (scale - 1) + 1) ** 2 / scale

    for name in parameters:
        anchor = complement[name]
        sample[name] = anchor + (sample[name] - anchor) * stretch

    return sample, (ndim - 1) * np.log(stretch)

1013 

1014 

class EnsembleProposal(BaseProposal):
    """Common base for proposals that use a complementary set of chains"""

    def __init__(self, priors, weight=1):
        super().__init__(priors, weight)

    def __call__(self, chain, chain_complement):
        """Propose a sample for ``chain`` using ``chain_complement``

        Boundaries are applied to the proposed sample only when the proposal
        reports a log factor equal to zero.

        Returns
        -------
        sample, log_factor:
            The proposed sample and the log factor given by ``propose``.
        """
        proposed, log_factor = self.propose(chain, chain_complement)
        if log_factor != 0:
            return proposed, log_factor
        return self.apply_boundaries(proposed), log_factor

1026 

1027 

class EnsembleStretch(EnsembleProposal):
    """Goodman & Weare (2010) stretch move for an ensemble of chains

    The complementary point is the current sample of one chain picked at
    random from the complement. The stretch factor is drawn following the
    form of g(z) given in Equation (9) of [1].

    Parameters
    ----------
    priors, weight:
        Forwarded to the parent class.
    scale:
        Scale parameter of the stretch-factor distribution.

    References
    ----------
    [1] Goodman & Weare (2010)
        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract

    """

    def __init__(self, priors, weight=1, scale=2):
        super().__init__(priors, weight)
        self.scale = scale

    def propose(self, chain, chain_complement):
        """Stretch the current sample relative to a random complement chain"""
        current = chain.current_sample
        pick = random.rng.integers(len(chain_complement))
        partner = chain_complement[pick].current_sample
        return _stretch_move(current, partner, self.scale, self.ndim, self.parameters)

1053 

1054 

def get_default_ensemble_proposal_cycle(priors):
    """Return a ProposalCycle containing one default EnsembleStretch proposal"""
    stretch = EnsembleStretch(priors)
    return ProposalCycle([stretch])

1057 

1058 

def get_proposal_cycle(string, priors, L1steps=1, warn=True):
    """Build the ProposalCycle selected by ``string``

    Parameters
    ----------
    string: str
        Name of the cycle. If it contains "gwA", the cycle is assembled from
        proposals acting on the parameter subsets in ``PARAMETER_SETS``;
        otherwise a generic cycle with no subsets is built. The string is also
        handed to ``remove_proposals_using_string`` to filter the result.
    priors:
        Prior dictionary passed to every proposal. For the "gwA" cycle it is
        queried for ``intrinsic``, ``extrinsic``, ``mass``, ``spin``,
        ``measured_spin``, ``phase``, ``sky``, ``non_fixed_keys`` and,
        if present, ``tidal``.
    L1steps: int
        Passed as ``scale_fits`` to the KDE and GMM proposals of the generic
        cycle. It is not used by the "gwA" cycle.
    warn: bool
        Passed to ``GMMProposal.check_dependencies``.

    Returns
    -------
    ProposalCycle

    Raises
    ------
    SamplerError
        If the "gwA" cycle is requested and
        ``GMMProposal.check_dependencies`` returns False.
    """
    big_weight = 10
    small_weight = 5
    tiny_weight = 0.5

    if "gwA" not in string:
        # Generic cycle: no parameter subsets are passed to the proposals
        plist = [
            AdaptiveGaussianProposal(priors, weight=big_weight),
            DifferentialEvolutionProposal(priors, weight=big_weight),
            UniformProposal(priors, weight=tiny_weight),
            KDEProposal(priors, weight=big_weight, scale_fits=L1steps),
            FisherMatrixProposal(priors, weight=big_weight),
        ]
        # The GMM proposal is optional here and only added when available
        if GMMProposal.check_dependencies(warn=warn):
            plist.append(GMMProposal(priors, weight=big_weight, scale_fits=L1steps))
    else:
        # Settings shared by the learning (KDE / GMM) proposals
        fit_options = {
            "first_fit": 1000,
            "nsamples_for_density": 10000,
            "fit_multiplier": 2,
        }

        non_calibration = [name for name in priors if "recalib" not in name]
        plist = [
            AdaptiveGaussianProposal(
                priors, weight=small_weight, subset=non_calibration
            ),
            DifferentialEvolutionProposal(
                priors, weight=small_weight, subset=non_calibration
            ),
        ]

        if GMMProposal.check_dependencies(warn=warn) is False:
            raise SamplerError(
                "the gwA proposal_cycle required the GMMProposal dependencies"
            )

        if priors.intrinsic:
            subset = PARAMETER_SETS["intrinsic"]
            plist.extend(
                [
                    AdaptiveGaussianProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    DifferentialEvolutionProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    KDEProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                    GMMProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                ]
            )

        if priors.extrinsic:
            subset = PARAMETER_SETS["extrinsic"]
            plist.extend(
                [
                    AdaptiveGaussianProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    DifferentialEvolutionProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    KDEProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                    GMMProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                ]
            )

        if priors.mass:
            subset = PARAMETER_SETS["mass"]
            plist.extend(
                [
                    DifferentialEvolutionProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    GMMProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                    FisherMatrixProposal(priors, weight=small_weight, subset=subset),
                ]
            )

        if priors.spin:
            subset = PARAMETER_SETS["spin"]
            plist.extend(
                [
                    DifferentialEvolutionProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    GMMProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                    FisherMatrixProposal(priors, weight=big_weight, subset=subset),
                ]
            )

        if priors.measured_spin:
            subset = PARAMETER_SETS["measured_spin"]
            plist.extend(
                [
                    AdaptiveGaussianProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    FisherMatrixProposal(priors, weight=small_weight, subset=subset),
                ]
            )

        if priors.mass and priors.spin:
            subset = PARAMETER_SETS["primary_spin_and_q"]
            plist.append(
                DifferentialEvolutionProposal(
                    priors, weight=small_weight, subset=subset
                )
            )

        # Not every prior object defines "tidal", hence the getattr default
        if getattr(priors, "tidal", False):
            subset = PARAMETER_SETS["tidal"]
            plist.extend(
                [
                    DifferentialEvolutionProposal(
                        priors, weight=small_weight, subset=subset
                    ),
                    PriorProposal(priors, weight=small_weight, subset=subset),
                ]
            )

        if priors.phase:
            plist.append(PhaseReversalProposal(priors, weight=tiny_weight))

        if priors.phase and "psi" in priors.non_fixed_keys:
            plist.extend(
                [
                    CorrelatedPolarisationPhaseJump(priors, weight=tiny_weight),
                    PhasePolarisationReversalProposal(priors, weight=tiny_weight),
                ]
            )

        if priors.sky:
            subset = PARAMETER_SETS["sky"]
            plist.extend(
                [
                    FisherMatrixProposal(priors, weight=small_weight, subset=subset),
                    GMMProposal(
                        priors, weight=small_weight, subset=subset, **fit_options
                    ),
                ]
            )

        # One single-parameter prior draw for each listed key that is sampled
        single_keys = (
            "time_jitter",
            "psi",
            "phi_12",
            "tilt_2",
            "lambda_1",
            "lambda_2",
        )
        plist.extend(
            PriorProposal(priors, subset=[name], weight=tiny_weight)
            for name in single_keys
            if name in priors.non_fixed_keys
        )

        if "chi_1_in_plane" in priors and "chi_2_in_plane" in priors:
            in_plane = ["chi_1_in_plane", "chi_2_in_plane", "phi_12"]
            plist.append(UniformProposal(priors, subset=in_plane, weight=tiny_weight))

        # Calibration parameters are only ever drawn from the prior
        calibration = [name for name in priors if "recalib_" in name]
        if calibration:
            plist.append(PriorProposal(priors, subset=calibration, weight=small_weight))

    return ProposalCycle(remove_proposals_using_string(plist, string))

1215 

1216 

def remove_proposals_using_string(plist, string):
    """Filter out proposal types disabled by "no<CODE>" tokens in ``string``

    ``string`` is split on "no" and every piece after the first is compared
    with the two-letter codes below. A piece that matches a code exactly
    removes all instances of the associated class from the list; any other
    piece is ignored. Codes therefore have to be written back to back, e.g.
    "noGMnoKD".

    Parameters
    ----------
    plist: list
        Proposal instances to filter.
    string: str
        Proposal-cycle string that may contain "no<CODE>" tokens.

    Returns
    -------
    list
        The proposals that were not removed.
    """
    code_to_class = {
        "DE": DifferentialEvolutionProposal,
        "AG": AdaptiveGaussianProposal,
        "ST": StretchProposal,
        "FG": FixedGaussianProposal,
        "NF": NormalizingFlowProposal,
        "KD": KDEProposal,
        "GM": GMMProposal,
        "PR": PriorProposal,
        "UN": UniformProposal,
        "FM": FisherMatrixProposal,
    }

    # The piece before the first "no" is the cycle name, not a code
    for code in string.split("no")[1:]:
        excluded = code_to_class.get(code)
        if excluded is not None:
            plist = [prop for prop in plist if not isinstance(prop, excluded)]
    return plist