Coverage for bilby/core/prior/joint.py: 85%

408 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import re 

2 

3import numpy as np 

4import scipy.stats 

5from scipy.special import erfinv 

6 

7from .base import Prior, PriorException 

8from ..utils import logger, infer_args_from_method, get_dict_with_properties 

9from ..utils import random 

10 

11 

class BaseJointPriorDist(object):
    def __init__(self, names, bounds=None):
        """
        A class defining JointPriorDist that will be overwritten with child
        classes defining the joint prior distributions between given parameters.

        Parameters
        ==========
        names: list (required)
            A list of the parameter names in the JointPriorDist. The
            listed parameters must have the same order that they appear in
            the lists of statistical parameters that may be passed in child
            classes. A single non-list name is wrapped in a list.
        bounds: list (optional)
            A list of ``(lower, upper)`` bounds, one per parameter. If this is
            not a list (e.g. None) every bound defaults to +/- infinity.

        Raises
        ======
        ValueError
            If the number of bounds differs from the number of names, a bound
            does not have exactly two entries, or a bound's upper value is not
            strictly greater than its lower value.
        TypeError
            If an individual bound is not a list, tuple or numpy array.
        """
        self.distname = "joint_dist"
        if not isinstance(names, list):
            self.names = [names]
        else:
            self.names = names

        self.num_vars = len(self.names)

        # set the bounds for each parameter
        if isinstance(bounds, list):
            if len(bounds) != len(self):
                raise ValueError("Wrong number of parameter bounds")

            # type-check each individual bound (the element, not the container)
            for bound in bounds:
                if not isinstance(bound, (list, tuple, np.ndarray)):
                    raise TypeError("Bound must be a list")
                if len(bound) != 2:
                    raise ValueError(
                        "Bounds must contain an upper and lower value."
                    )
                if bound[1] <= bound[0]:
                    raise ValueError("Bounds are not properly set")
        else:
            bounds = [(-np.inf, np.inf) for _ in self.names]
        self.bounds = {name: val for name, val in zip(self.names, bounds)}

        self._current_sample = {}  # initialise empty sample
        self._uncorrelated = None
        self._current_lnprob = None

        # a dictionary of the parameters as requested by the prior
        self.requested_parameters = dict()
        self.reset_request()

        # a dictionary of the rescaled parameters
        self.rescale_parameters = dict()
        self.reset_rescale()

        # a list of sampled parameters
        self.reset_sampled()

    def reset_sampled(self):
        """Forget which parameters have been sampled and the current sample."""
        self.sampled_parameters = []
        self.current_sample = {}

    def filled_request(self):
        """
        Check if all requested parameters have been filled.

        Returns
        =======
        bool: True if no entry of ``requested_parameters`` is None.
        """
        return all(val is not None for val in self.requested_parameters.values())

    def reset_request(self):
        """
        Reset the requested parameters to None.
        """
        for name in self.names:
            self.requested_parameters[name] = None

    def filled_rescale(self):
        """
        Check if all the rescaled parameters have been filled.

        Returns
        =======
        bool: True if no entry of ``rescale_parameters`` is None.
        """
        return all(val is not None for val in self.rescale_parameters.values())

    def reset_rescale(self):
        """
        Reset the rescaled parameters to None.
        """
        for name in self.names:
            self.rescale_parameters[name] = None

    def get_instantiation_dict(self):
        """
        Build the keyword arguments needed to re-create this instance.

        The argument names are taken from the signature of ``__init__`` and
        their values from the instance's attributes/properties; list values
        are round-tripped through numpy so nested arrays become plain lists.
        """
        subclass_args = infer_args_from_method(self.__init__)
        dict_with_properties = get_dict_with_properties(self)
        instantiation_dict = dict()
        for key in subclass_args:
            if isinstance(dict_with_properties[key], list):
                value = np.asarray(dict_with_properties[key]).tolist()
            else:
                value = dict_with_properties[key]
            instantiation_dict[key] = value
        return instantiation_dict

    def __len__(self):
        """Return the number of parameters in the distribution."""
        return len(self.names)

    def __repr__(self):
        """Overrides the special method __repr__.

        Returns a representation of this instance that resembles how it is instantiated.
        Works correctly for all child classes

        Returns
        =======
        str: A string representation of this instance

        """
        dist_name = self.__class__.__name__
        instantiation_dict = self.get_instantiation_dict()
        args = ", ".join(
            [
                "{}={}".format(key, repr(instantiation_dict[key]))
                for key in instantiation_dict
            ]
        )
        return "{}({})".format(dist_name, args)

    def prob(self, samp):
        """
        Get the probability of a sample. For bounded priors the
        probability will not be properly normalised.
        """
        return np.exp(self.ln_prob(samp))

    def _check_samp(self, value):
        """
        Convert a sample (or samples) to a 2d array and flag any that lie
        outside the parameter bounds.

        Parameters
        ==========
        value: array_like
            A 1d vector of the sample, or 2d array of sample values with shape
            NxM, where N is the number of samples and M is the number of
            parameters.

        Returns
        =======
        samp: array_like
            returns the input value as an NxM sample array
        outbounds: array_like
            Boolean array of length N that is True for every sample in samp
            with at least one parameter outside its given bounds

        Raises
        ======
        ValueError
            If the input cannot be interpreted as an NxM array.
        """
        samp = np.array(value)
        if len(samp.shape) == 1:
            samp = samp.reshape(1, self.num_vars)

        if len(samp.shape) != 2:
            raise ValueError("Array is the wrong shape")
        elif samp.shape[1] != self.num_vars:
            raise ValueError("Array is the wrong shape")

        # a sample is out of bounds if ANY of its parameters is out of bounds,
        # so accumulate the per-parameter violations rather than overwriting
        outbounds = np.zeros(samp.shape[0], dtype=bool)
        for s, bound in zip(samp.T, self.bounds.values()):
            outbounds |= (s < bound[0]) | (s > bound[1])
        return samp, outbounds

    def ln_prob(self, value):
        """
        Get the log-probability of a sample. For bounded priors the
        probability will not be properly normalised.

        Parameters
        ==========
        value: array_like
            A 1d vector of the sample, or 2d array of sample values with shape
            NxM, where N is the number of samples and M is the number of
            parameters.

        Returns
        =======
        float or array: a scalar for a single sample, otherwise an array of
            length N.
        """
        samp, outbounds = self._check_samp(value)
        lnprob = -np.inf * np.ones(samp.shape[0])
        lnprob = self._ln_prob(samp, lnprob, outbounds)
        if samp.shape[0] == 1:
            return lnprob[0]
        else:
            return lnprob

    def _ln_prob(self, samp, lnprob, outbounds):
        """
        Get the log-probability of a sample. For bounded priors the
        probability will not be properly normalised. **this method needs overwritten by child class**

        Parameters
        ==========
        samp: vector
            sample to evaluate the ln_prob at
        lnprob: vector
            of -inf passed in with the same shape as the number of samples
        outbounds: array_like
            boolean array showing which samples in lnprob vector are out of the given bounds

        Returns
        =======
        lnprob: vector
            array of lnprob values for each sample given
        """
        # Here is where the subclass should overwrite the ln_prob method
        return lnprob

    def sample(self, size=1, **kwargs):
        """
        Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten

        The draw is stored in ``current_sample`` keyed by parameter name: a
        single value per parameter when ``size == 1``, otherwise a 1d array.

        Parameters
        ==========
        size: int
            number of samples to generate, defaults to 1 (None is treated as 1)
        kwargs: dict
            passed through to ``_sample``
        """
        if size is None:
            size = 1
        samps = self._sample(size=size, **kwargs)
        for i, name in enumerate(self.names):
            if size == 1:
                self.current_sample[name] = samps[:, i].flatten()[0]
            else:
                self.current_sample[name] = samps[:, i].flatten()

    def _sample(self, size, **kwargs):
        """
        Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**)

        Parameters
        ==========
        size: int
            number of samples to generate, defaults to 1

        Returns
        =======
        array: a (size, number of parameters) array of samples
        """
        samps = np.zeros((size, len(self)))
        # Here is where the subclass should overwrite the sampling method
        return samps

    def rescale(self, value, **kwargs):
        """
        Rescale from a unit hypercube to JointPriorDist. Note that no
        bounds are applied in the rescale function. (child classes need to
        overwrite accompanying method _rescale().

        Parameters
        ==========
        value: array
            A 1d vector sample (one for each parameter) drawn from a uniform
            distribution between 0 and 1, or a 2d NxM array of samples where
            N is the number of samples and M is the number of parameters.
        kwargs: dict
            All keyword args that need to be passed to _rescale method, these keyword
            args are called in the JointPrior rescale methods for each parameter

        Returns
        =======
        array:
            The rescaled sample(s), with length-one dimensions squeezed out.

        Raises
        ======
        ValueError
            If the input cannot be interpreted as an NxM array.
        """
        samp = np.array(value)
        if len(samp.shape) == 1:
            samp = samp.reshape(1, self.num_vars)

        if len(samp.shape) != 2:
            raise ValueError("Array is the wrong shape")
        elif samp.shape[1] != self.num_vars:
            raise ValueError("Array is the wrong shape")

        samp = self._rescale(samp, **kwargs)
        return np.squeeze(samp)

    def _rescale(self, samp, **kwargs):
        """
        rescale a sample from a unit hypercube to the joint dist (**needs to be ovewritten by child class**)

        Parameters
        ==========
        samp: numpy array
            this is a vector sample drawn from a uniform distribution to be rescaled to the distribution
        """
        # Here is where the subclass should overwrite the rescale method
        return samp

313 

314 

class MultivariateGaussianDist(BaseJointPriorDist):
    def __init__(
        self,
        names,
        nmodes=1,
        mus=None,
        sigmas=None,
        corrcoefs=None,
        covs=None,
        weights=None,
        bounds=None,
    ):
        """
        A class defining a multi-variate Gaussian, allowing multiple modes for
        a Gaussian mixture model.

        Note: if using a multivariate Gaussian prior, with bounds, this can
        lead to biases in the marginal likelihood estimate and posterior
        estimate for nested samplers routines that rely on sampling from a unit
        hypercube and having a prior transform, e.g., nestle, dynesty and
        MultiNest.

        Parameters
        ==========
        names: list
            A list of the parameter names in the multivariate Gaussian. The
            listed parameters must have the same order that they appear in
            the lists of means, standard deviations, and the correlation
            coefficient, or covariance, matrices.
        nmodes: int
            The number of modes for the mixture model. This defaults to 1,
            which will be checked against the shape of the other inputs.
        mus: array_like
            A list of lists of means of each mode in a multivariate Gaussian
            mixture model. A single list can be given for a single mode. If
            this is None then means at zero will be assumed.
        sigmas: array_like
            A list of lists of the standard deviations of each mode of the
            multivariate Gaussian. These are combined with ``corrcoefs`` to
            build each mode's covariance and are ignored for any mode whose
            covariance matrix is given directly. If this is None unit
            variances will be assumed.
        corrcoefs: array
            A list of square matrices containing the correlation coefficients
            of the parameters for each mode. If this is None it will be assumed
            that the parameters are uncorrelated.
        covs: array
            A list of square matrices containing the covariance matrix of the
            multivariate Gaussian. When given, these take precedence over
            ``sigmas`` and ``corrcoefs``.
        weights: list
            A list of weights (relative probabilities) for each mode of the
            multivariate Gaussian. This will default to equal weights for each
            mode.
        bounds: list
            A list of bounds on each parameter. The defaults are for bounds at
            +/- infinity.

        Raises
        ======
        TypeError, ValueError
            If any of the inputs has the wrong type, shape or number of modes.
        """
        super(MultivariateGaussianDist, self).__init__(names=names, bounds=bounds)
        for name in self.names:
            bound = self.bounds[name]
            if bound[0] != -np.inf or bound[1] != np.inf:
                logger.warning(
                    "If using bounded ranges on the multivariate "
                    "Gaussian this will lead to biased posteriors "
                    "for nested sampling routines that require "
                    "a prior transform."
                )
        self.distname = "mvg"
        self.mus = []
        self.covs = []
        self.corrcoefs = []
        self.sigmas = []
        self.logprodsigmas = []  # log of product of sigmas, needed for "standard" multivariate normal
        self.weights = []
        self.eigvalues = []
        self.eigvectors = []
        self.sqeigvalues = []  # square root of the eigenvalues
        self.mvn = []  # list of multivariate normal distributions

        # put values in lists if required
        if nmodes == 1:
            if mus is not None:
                if len(np.shape(mus)) == 1:
                    mus = [mus]
                elif len(np.shape(mus)) == 0:
                    raise ValueError("Must supply a list of means")
            if sigmas is not None:
                if len(np.shape(sigmas)) == 1:
                    sigmas = [sigmas]
                elif len(np.shape(sigmas)) == 0:
                    raise ValueError("Must supply a list of standard deviations")
            if covs is not None:
                if isinstance(covs, np.ndarray):
                    covs = [covs]
                elif isinstance(covs, list):
                    if len(np.shape(covs)) == 2:
                        covs = [np.array(covs)]
                    elif len(np.shape(covs)) != 3:
                        raise TypeError("List of covariances the wrong shape")
                else:
                    raise TypeError("Must pass a list of covariances")
            if corrcoefs is not None:
                if isinstance(corrcoefs, np.ndarray):
                    corrcoefs = [corrcoefs]
                elif isinstance(corrcoefs, list):
                    if len(np.shape(corrcoefs)) == 2:
                        corrcoefs = [np.array(corrcoefs)]
                    elif len(np.shape(corrcoefs)) != 3:
                        raise TypeError(
                            "List of correlation coefficients the wrong shape"
                        )
                else:
                    raise TypeError("Must pass a list of correlation coefficients")
            if weights is not None:
                if isinstance(weights, (int, float)):
                    weights = [weights]
                elif isinstance(weights, list):
                    if len(weights) != 1:
                        raise ValueError("Wrong number of weights given")

        # every supplied per-mode input must be a list with one entry per mode
        for val in [mus, sigmas, covs, corrcoefs, weights]:
            if val is None:
                continue
            if not isinstance(val, list):
                raise TypeError("Value must be a list")
            if len(val) != nmodes:
                raise ValueError("Wrong number of modes given")

        # add the modes
        self.nmodes = 0
        for i in range(nmodes):
            mu = mus[i] if mus is not None else None
            sigma = sigmas[i] if sigmas is not None else None
            corrcoef = corrcoefs[i] if corrcoefs is not None else None
            cov = covs[i] if covs is not None else None
            weight = weights[i] if weights is not None else 1.0

            self.add_mode(mu, sigma, corrcoef, cov, weight)

    def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0):
        """
        Add a new mode.

        Parameters
        ==========
        mus: array_like
            The means of the mode. Zeros are used if this is None.
        sigmas: array_like
            The standard deviations of the mode, combined with ``corrcoef``.
            Ignored if ``cov`` is given. Unit variances are used if this is
            None.
        corrcoef: array_like
            The correlation coefficient matrix of the mode. Ignored if ``cov``
            is given. The identity (uncorrelated) is used if this is None.
        cov: array_like
            The covariance matrix of the mode; takes precedence over
            ``sigmas`` and ``corrcoef``.
        weight: float
            The relative weight of the mode. None is treated as 1.0.

        Raises
        ======
        TypeError
            If ``mus`` or ``sigmas`` cannot be converted to a list.
        ValueError
            If the covariance/correlation matrix is not a square 2d array of
            the right size, is not symmetric, has a non-unit diagonal
            (correlation matrix), is not positive definite, or the number of
            standard deviations differs from the number of parameters.
        RuntimeError
            If the eigen-decomposition of the correlation matrix fails.
        """
        # add means
        if mus is not None:
            try:
                self.mus.append(list(mus))  # means
            except TypeError:
                raise TypeError("'mus' must be a list")
        else:
            self.mus.append(np.zeros(self.num_vars))

        # add the covariances if supplied
        if cov is not None:
            self.covs.append(np.asarray(cov))

            if len(self.covs[-1].shape) != 2:
                raise ValueError("Covariance matrix must be a 2d array")

            if (
                self.covs[-1].shape[0] != self.covs[-1].shape[1]
                or self.covs[-1].shape[0] != self.num_vars
            ):
                raise ValueError("Covariance shape is inconsistent")

            # check matrix is symmetric
            if not np.allclose(self.covs[-1], self.covs[-1].T):
                raise ValueError("Covariance matrix is not symmetric")

            self.sigmas.append(np.sqrt(np.diag(self.covs[-1])))  # standard deviations

            # convert covariance into a correlation coefficient matrix
            D = self.sigmas[-1] * np.identity(self.covs[-1].shape[0])
            Dinv = np.linalg.inv(D)
            self.corrcoefs.append(np.dot(np.dot(Dinv, self.covs[-1]), Dinv))
        elif corrcoef is not None or sigmas is not None:
            # Only one of the two may have been supplied: fill in the missing
            # one with its documented default instead of discarding the one
            # that was given.
            if corrcoef is None:
                corrcoef = np.eye(self.num_vars)  # uncorrelated parameters
            if sigmas is None:
                sigmas = np.ones(self.num_vars)  # unit variances

            self.corrcoefs.append(np.asarray(corrcoef))

            if len(self.corrcoefs[-1].shape) != 2:
                raise ValueError(
                    "Correlation coefficient matrix must be a 2d array."
                )

            if (
                self.corrcoefs[-1].shape[0] != self.corrcoefs[-1].shape[1]
                or self.corrcoefs[-1].shape[0] != self.num_vars
            ):
                raise ValueError(
                    "Correlation coefficient matrix shape is inconsistent"
                )

            # check matrix is symmetric
            if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T):
                raise ValueError("Correlation coefficient matrix is not symmetric")

            # check diagonal is all ones
            if not np.all(np.diag(self.corrcoefs[-1]) == 1.0):
                raise ValueError("Correlation coefficient matrix is not correct")

            try:
                self.sigmas.append(list(sigmas))  # standard deviations
            except TypeError:
                raise TypeError("'sigmas' must be a list")

            if len(self.sigmas[-1]) != self.num_vars:
                raise ValueError(
                    "Number of standard deviations must be the "
                    "same as the number of parameters."
                )

            # convert correlation coefficients to covariance matrix
            D = self.sigmas[-1] * np.identity(self.corrcoefs[-1].shape[0])
            self.covs.append(np.dot(D, np.dot(self.corrcoefs[-1], D)))
        else:
            # set unit variance uncorrelated covariance
            self.corrcoefs.append(np.eye(self.num_vars))
            self.covs.append(np.eye(self.num_vars))
            self.sigmas.append(np.ones(self.num_vars))

        # compute log of product of sigmas, needed for "standard" multivariate normal
        self.logprodsigmas.append(np.log(np.prod(self.sigmas[-1])))

        # get eigen values and vectors
        try:
            evals, evecs = np.linalg.eig(self.corrcoefs[-1])
            self.eigvalues.append(evals)
            self.eigvectors.append(evecs)
        except Exception as e:
            raise RuntimeError(
                "Problem getting eigenvalues and vectors: {}".format(e)
            )

        # check eigenvalues are positive
        if np.any(self.eigvalues[-1] <= 0.0):
            raise ValueError(
                "Correlation coefficient matrix is not positive definite"
            )
        self.sqeigvalues.append(np.sqrt(self.eigvalues[-1]))

        # set the weights
        if weight is None:
            self.weights.append(1.0)
        else:
            self.weights.append(weight)

        # set the cumulative relative weights
        self.cumweights = np.cumsum(self.weights) / np.sum(self.weights)

        # add the mode
        self.nmodes += 1

        # add "standard" multivariate normal distribution
        # - when the typical scales of the parameters are very different,
        #   multivariate_normal() may complain that the covariance matrix is singular
        # - instead pass zero means and correlation matrix instead of covariance matrix
        #   to get the equivalent of a standard normal distribution in higher dimensions
        # - this modifies the multivariate normal PDF as follows:
        #     multivariate_normal(mean=mus, cov=cov).logpdf(x)
        #     = multivariate_normal(mean=0, cov=corrcoefs).logpdf((x - mus)/sigmas) - logprodsigmas
        self.mvn.append(
            scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1])
        )

    def _rescale(self, samp, **kwargs):
        """
        Map unit-hypercube samples onto one mode of the Gaussian mixture.

        Each coordinate is first mapped to a standard normal deviate via the
        inverse error function, then scaled by the square-root eigenvalues,
        rotated by the eigenvectors of the mode's correlation matrix, scaled
        by the mode's standard deviations and shifted by its means.

        Parameters
        ==========
        samp: array
            (N, number of parameters) array of values in [0, 1].
        kwargs: dict
            ``mode`` selects the mode index. If absent or None, mode 0 is used
            for a single-mode distribution; otherwise one mode is drawn at
            random according to the mode weights.
        """
        mode = kwargs.get("mode")

        if mode is None:
            if self.nmodes == 1:
                mode = 0
            else:
                mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0]

        samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5

        # rotate and scale to the multivariate normal shape
        samp = self.mus[mode] + self.sigmas[mode] * np.einsum(
            "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode]
        )
        return samp

    def _sample(self, size, **kwargs):
        """
        Draw ``size`` samples, rejecting and redrawing any that fall outside
        the bounds.

        Parameters
        ==========
        size: int
            The number of samples to draw.
        kwargs: dict
            ``mode`` selects the mode index used for every sample. If absent
            or None, modes are drawn according to the mode weights (one per
            sample when ``size > 1``).

        Returns
        =======
        array: a (size, number of parameters) array of samples.
        """
        try:
            mode = kwargs["mode"]
        except KeyError:
            mode = None

        if mode is None:
            if self.nmodes == 1:
                mode = 0
            else:
                if size == 1:
                    mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0]
                else:
                    # pick modes
                    mode = [
                        np.argwhere(self.cumweights - r > 0)[0][0]
                        for r in random.rng.uniform(0, 1, size)
                    ]

        samps = np.zeros((size, len(self)))
        for i in range(size):
            inbound = False
            while not inbound:
                # sample the multivariate Gaussian keys
                vals = random.rng.uniform(0, 1, len(self))

                if isinstance(mode, list):
                    samp = np.atleast_1d(self.rescale(vals, mode=mode[i]))
                else:
                    samp = np.atleast_1d(self.rescale(vals, mode=mode))
                samps[i, :] = samp

                # check sample is in bounds (otherwise perform another draw)
                outbound = False
                for name, val in zip(self.names, samp):
                    if val < self.bounds[name][0] or val > self.bounds[name][1]:
                        outbound = True
                        break

                if not outbound:
                    inbound = True

        return samps

    def _ln_prob(self, samp, lnprob, outbounds):
        """
        Log-density of the (weighted) Gaussian mixture at each sample.

        The mode densities are combined in log space with ``np.logaddexp``,
        each weighted by its normalised mode weight, so the density matches
        the mode probabilities used when sampling. For a single mode the
        weight is 1 and contributes nothing. Out-of-bounds samples are set to
        -inf; the result is not renormalised for the bounds.

        Parameters
        ==========
        samp: array
            (N, number of parameters) array of samples.
        lnprob: array
            length-N array of -inf that is filled in and returned.
        outbounds: array
            length-N boolean array flagging out-of-bounds samples.
        """
        # log of the normalised mode weights (a zero weight gives -inf)
        with np.errstate(divide="ignore"):
            lnweights = np.log(
                np.asarray(self.weights, dtype=float) / np.sum(self.weights)
            )

        for j in range(samp.shape[0]):
            # loop over the modes and sum the weighted probabilities
            for i in range(self.nmodes):
                # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode()
                z = (samp[j] - self.mus[i]) / self.sigmas[i]
                lnprob[j] = np.logaddexp(
                    lnprob[j],
                    lnweights[i] + self.mvn[i].logpdf(z) - self.logprodsigmas[i],
                )

        # set out-of-bounds values to -inf
        lnprob[outbounds] = -np.inf
        return lnprob

    def __eq__(self, other):
        """
        Compare two distributions attribute by attribute.

        Frozen scipy ``mvn`` objects are only checked for type; float arrays
        are compared with ``np.allclose`` on their finite entries (with the
        positions of non-finite entries required to match); other arrays and
        lists must be exactly equal; everything else uses ``==``.
        """
        if self.__class__ != other.__class__:
            return False
        if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
            return False
        for key in self.__dict__:
            thisval = self.__dict__[key]
            otherval = other.__dict__[key]
            if key == "mvn":
                if len(thisval) != len(otherval):
                    return False
                for thismvn, othermvn in zip(thisval, otherval):
                    if not isinstance(
                        thismvn, scipy.stats._multivariate.multivariate_normal_frozen
                    ) or not isinstance(
                        othermvn, scipy.stats._multivariate.multivariate_normal_frozen
                    ):
                        return False
            elif isinstance(thisval, (np.ndarray, list)):
                thisarr = np.asarray(thisval)
                otherarr = np.asarray(otherval)
                if thisarr.dtype == float and otherarr.dtype == float:
                    fin1 = np.isfinite(thisarr)
                    fin2 = np.isfinite(otherarr)
                    if not np.array_equal(fin1, fin2):
                        return False
                    if not np.allclose(thisarr[fin1], otherarr[fin2], atol=1e-15):
                        return False
                else:
                    if not np.array_equal(thisarr, otherarr):
                        return False
            else:
                if not thisval == otherval:
                    return False
        return True

    @classmethod
    def from_repr(cls, string):
        """Generate the distribution from its __repr__"""
        return cls._from_repr(string)

    @classmethod
    def _from_repr(cls, string):
        """
        Parse the ``key=value`` arguments of a repr string and instantiate.

        Raises
        ======
        AttributeError
            If an argument is not accepted by ``cls.__init__``.
        """
        subclass_args = infer_args_from_method(cls.__init__)

        string = string.replace(" ", "")
        kwargs = cls._split_repr(string)
        for key in list(kwargs):
            val = kwargs[key]
            if key not in subclass_args:
                raise AttributeError(
                    "Unknown argument {} for class {}".format(key, cls.__name__)
                )
            else:
                kwargs[key.strip()] = Prior._parse_argument_string(val)

        return cls(**kwargs)

    @classmethod
    def _split_repr(cls, string):
        """Split ``a=1, b=[...], c={...}`` into a dict of raw argument strings."""
        string = string.replace(",", ", ")
        # see https://stackoverflow.com/a/72146415/1862861
        args = re.findall(r"(\w+)=(\[.*?]|{.*?}|\S+)(?=\s*,\s*\w+=|\Z)", string)
        kwargs = dict()
        for key, arg in args:
            kwargs[key.strip()] = arg
        return kwargs

718 

719 

class MultivariateNormalDist(MultivariateGaussianDist):
    """Alias of :class:`~bilby.core.prior.MultivariateGaussianDist`.

    Adds nothing to the parent class; it exists so the distribution can also
    be referred to by its "normal" name.
    """

722 

723 

class JointPrior(Prior):
    def __init__(self, dist, name=None, latex_label=None, unit=None):
        """This defines the single parameter Prior object for parameters that belong to a JointPriorDist

        Parameters
        ==========
        dist: ChildClass of BaseJointPriorDist
            The shared JointPriorDistribution that this parameter belongs to
        name: str
            Name of this parameter. Must be contained in dist.names
        latex_label: str
            See superclass
        unit: str
            See superclass

        Raises
        ======
        TypeError
            If ``dist`` is not a BaseJointPriorDist (or subclass) instance.
        ValueError
            If ``name`` is not one of ``dist.names``.
        """
        # accept any direct or indirect subclass (e.g. MultivariateNormalDist,
        # which derives from MultivariateGaussianDist)
        if not isinstance(dist, BaseJointPriorDist):
            raise TypeError(
                "Must supply a JointPriorDist object instance to be shared by all joint params"
            )

        if name not in dist.names:
            raise ValueError(
                "'{}' is not a parameter in the JointPriorDist".format(name)
            )

        # dist must be set before the superclass sets minimum/maximum, since
        # those setters write through to dist.bounds
        self.dist = dist
        super(JointPrior, self).__init__(
            name=name,
            latex_label=latex_label,
            unit=unit,
            minimum=dist.bounds[name][0],
            maximum=dist.bounds[name][1],
        )

    @property
    def minimum(self):
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        # keep the shared distribution's bounds in sync with this parameter
        self._minimum = minimum
        self.dist.bounds[self.name] = (minimum, self.dist.bounds[self.name][1])

    @property
    def maximum(self):
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        # keep the shared distribution's bounds in sync with this parameter
        self._maximum = maximum
        self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum)

    def rescale(self, val, **kwargs):
        """
        Scale a unit hypercube sample to the prior.

        The value is stored on the shared distribution; only once every
        parameter of the distribution has supplied a value is the joint
        rescaling performed (and the stored values reset).

        Parameters
        ==========
        val: array_like
            value drawn from unit hypercube to be rescaled onto the prior
        kwargs: dict
            all kwargs passed to the dist.rescale method
        Returns
        =======
        float:
            The rescaled joint sample once all parameters are filled,
            otherwise an empty list.
        """
        self.dist.rescale_parameters[self.name] = val

        if self.dist.filled_rescale():
            values = np.array(list(self.dist.rescale_parameters.values())).T
            samples = self.dist.rescale(values, **kwargs)
            self.dist.reset_rescale()
            return samples
        else:
            return []  # return empty list

    def sample(self, size=1, **kwargs):
        """
        Draw a sample from the prior.

        A joint sample is generated the first time any parameter of the
        distribution is sampled; the remaining parameters then return their
        part of that same joint sample. The shared sample is cleared once
        every parameter has been sampled.

        Parameters
        ==========
        size: int, float (defaults to 1)
            number of samples to draw
        kwargs: dict
            kwargs passed to the dist.sample method
        Returns
        =======
        float:
            A sample from the prior parameter.
        """
        if self.name in self.dist.sampled_parameters:
            logger.warning(
                "You have already drawn a sample from parameter "
                "'{}'. The same sample will be "
                "returned".format(self.name)
            )

        if len(self.dist.current_sample) == 0:
            # generate a sample
            self.dist.sample(size=size, **kwargs)

        sample = self.dist.current_sample[self.name]

        if self.name not in self.dist.sampled_parameters:
            self.dist.sampled_parameters.append(self.name)

        if len(self.dist.sampled_parameters) == len(self.dist):
            # reset samples
            self.dist.reset_sampled()
        self.least_recently_sampled = sample
        return sample

    def ln_prob(self, val):
        """
        Return the natural logarithm of the prior probability. Note that this
        will not be correctly normalised if there are bounds on the
        distribution.

        The joint log-probability is only evaluated once every parameter of
        the distribution has requested a value; until then zero(s) are
        returned so the joint value is counted exactly once.

        Parameters
        ==========
        val: array_like
            value to evaluate the prior log-prob at
        Returns
        =======
        float:
            the logp value for the prior at given sample

        Raises
        ======
        ValueError
            If the parameters were given different numbers of values.
        TypeError
            If ``val`` is neither a number nor sized.
        """
        self.dist.requested_parameters[self.name] = val

        if self.dist.filled_request():
            # all required parameters have been set
            values = list(self.dist.requested_parameters.values())

            # check for the same number of values for each parameter: either
            # all neighbours are scalars, or both are sequences of equal length
            for first, second in zip(values[:-1], values[1:]):
                first_is_seq = isinstance(first, (list, np.ndarray))
                second_is_seq = isinstance(second, (list, np.ndarray))
                if first_is_seq or second_is_seq:
                    if not (
                        first_is_seq
                        and second_is_seq
                        and len(first) == len(second)
                    ):
                        raise ValueError(
                            "Each parameter must have the same "
                            "number of requested values."
                        )

            lnp = self.dist.ln_prob(np.asarray(values).T)

            # reset the requested parameters
            self.dist.reset_request()
            return lnp
        else:
            # if not all parameters have been requested yet, just return 0
            # (numpy scalars such as np.int64/np.float32 count as numbers too)
            if isinstance(val, (float, int, np.number)):
                return 0.0
            else:
                try:
                    # check value has a length
                    len(val)
                except Exception as e:
                    raise TypeError("Invalid type for ln_prob: {}".format(e))

                if len(val) == 1:
                    return 0.0
                else:
                    return np.zeros_like(val)

    def prob(self, val):
        """Return the prior probability of val

        Parameters
        ==========
        val: array_like
            value to evaluate the prior prob at

        Returns
        =======
        float:
            the p value for the prior at given sample
        """
        return np.exp(self.ln_prob(val))

916 

917 

class MultivariateGaussian(JointPrior):
    def __init__(self, dist, name=None, latex_label=None, unit=None):
        """
        A single-parameter prior backed by a shared MultivariateGaussianDist.

        Parameters
        ==========
        dist: MultivariateGaussianDist
            The shared multivariate Gaussian distribution.
        name: str
            Name of this parameter; must be one of ``dist.names``.
        latex_label: str
            See superclass
        unit: str
            See superclass

        Raises
        ======
        JointPriorDistError
            If ``dist`` is not a MultivariateGaussianDist instance.
        """
        if isinstance(dist, MultivariateGaussianDist):
            super(MultivariateGaussian, self).__init__(
                dist=dist, name=name, latex_label=latex_label, unit=unit
            )
        else:
            raise JointPriorDistError(
                "dist object must be instance of MultivariateGaussianDist"
            )

927 

928 

class MultivariateNormal(MultivariateGaussian):
    """Alias of the :class:`bilby.core.prior.MultivariateGaussian` prior.

    Adds nothing to the parent class; it exists so the prior can also be
    referred to by its "normal" name.
    """

932 

933 

class JointPriorDistError(PriorException):
    """Error raised when a JointPrior is given an unsuitable JointPriorDist
    (for example, a MultivariateGaussian whose dist is not a
    MultivariateGaussianDist)."""