Loading [MathJax]/extensions/TeX/AMSsymbols.js
LALInference 4.1.9.1-5e288d3
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Modules Pages
wrapper.py
Go to the documentation of this file.
import collections.abc
import numbers

import lal
import lalinference as li

class LIVariablesWrap(collections.abc.MutableMapping):
    """
    Dict-like view of a LALInferenceVariables structure.

    Keys are variable names. Reading supports REAL8, INT4 and UINT4
    variables; any other LALInference type raises TypeError on access.
    """
    def __init__(self, init=None):
        """
        Wrapper to present a LALInferenceVariable as a dict.

        Parameters
        ----------
        init : dict
            Initialise with the given dictionary.
            If init is itself a LALInferenceVariables C struct
            then the wrapper will wrap around it but not reallocate memory
        """
        self.owner = True  # Python should manage the object's memory
        if isinstance(init, li.Variables):
            # Wrap the existing struct in place; its memory belongs elsewhere.
            self.v = init
            self.owner = False
        else:
            self.v = li.Variables()
            if init:
                self.update(init)

    def __contains__(self, key):
        # Direct existence check. The Mapping default goes through __getitem__,
        # which raises TypeError (not KeyError) for unsupported variable types.
        return bool(li.CheckVariable(self.v, key))

    def __delitem__(self, key):
        if li.CheckVariable(self.v, key):
            li.RemoveVariable(self.v, key)
        else:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Store a numeric value under ``key``.

        A new key is added with varyType ``LALINFERENCE_PARAM_LINEAR``.
        Integral values become INT4 variables; other real values become REAL8.
        An existing key keeps its LALInference type and varyType. An integral
        value written to an existing REAL8 variable is stored as a float.

        Raises
        ------
        TypeError
            If ``value`` is not a real number (bools are rejected), or if it
            cannot be stored in the existing variable's type.
        """
        # bool is a subclass of int but was never accepted; keep rejecting it.
        if isinstance(value, bool) or not isinstance(value, numbers.Real):
            raise TypeError('Unsupported type: ', key, type(value))
        integral = isinstance(value, numbers.Integral)
        if li.CheckVariable(self.v, key):
            # Preserve the existing variable's type and varyType.
            vtype = self.type(key)
            vary = self.varyType(key)
        else:
            vtype = li.LALINFERENCE_INT4_t if integral else li.LALINFERENCE_REAL8_t
            vary = li.LALINFERENCE_PARAM_LINEAR
        if vtype == li.LALINFERENCE_REAL8_t:
            # float() also converts numpy floating scalars and ints.
            li.AddREAL8Variable(self.v, key, float(value), vary)
        elif vtype == li.LALINFERENCE_INT4_t and integral:
            li.AddINT4Variable(self.v, key, int(value), vary)
        elif vtype == li.LALINFERENCE_UINT4_t and integral:
            li.AddUINT4Variable(self.v, key, int(value), vary)
        else:
            raise TypeError('Unsupported type: ', key, type(value))

    def __getitem__(self, key):
        if not li.CheckVariable(self.v, key):
            raise KeyError(key)
        vtype = self.type(key)
        if vtype == li.LALINFERENCE_REAL8_t:
            return li.GetREAL8Variable(self.v, key)
        elif vtype == li.LALINFERENCE_INT4_t:
            return li.GetINT4Variable(self.v, key)
        elif vtype == li.LALINFERENCE_UINT4_t:
            return li.GetUINT4Variable(self.v, key)
        raise TypeError('Unsupported type: ', key, vtype)

    def __iter__(self):
        return _variterator(self.v)

    def __len__(self):
        return self.v.dimension

    def __del__(self):
        # Only clear structs this wrapper allocated. getattr/hasattr guard
        # against a partially-initialised instance (e.g. li.Variables() failed).
        if getattr(self, 'owner', False) and hasattr(self, 'v'):
            li.ClearVariables(self.v)

    def __repr__(self):
        return 'LIVariablesWrap(' + repr(dict(self)) + ')'

    def __str__(self):
        return str(dict(self))

    def varyType(self, key):
        """
        Return the lalinference variable's varyType

        Parameters
        ----------
        key : str
            The name of the variable to look up

        Returns
        -------
        varytype : lalinference.varyType (e.g. lalinference.LALINFERENCE_PARAM_FIXED)

        Raises
        ------
        KeyError
            If ``key`` is not present.
        """
        if not li.CheckVariable(self.v, key):
            raise KeyError(key)
        return li.GetVariableVaryType(self.v, key)

    def type(self, key):
        """
        Return the lalinference variable's data type

        Parameters
        ----------
        key : str
            The name of the variable to look up

        Returns
        -------
        type : the LALInference type (e.g. lalinference.LALINFERENCE_REAL8_t)

        Raises
        ------
        KeyError
            If ``key`` is not present.
        """
        if not li.CheckVariable(self.v, key):
            raise KeyError(key)
        return li.GetVariableType(self.v, key)
92
93class _variterator(object):
94 def __init__(self, var):
95 self.varitem = var.head
96
97 def __iter__(self):
98 return self
99
100 def __next__(self):
101 if not self.varitem:
102 raise StopIteration
103 else:
104 this = self.varitem
105 self.varitem=self.varitem.next
106 return(this.name)
107
108 def next(self):
109 return self.__next__()
110
111
class LALInferenceCBCWrapper(object):
    """
    Class to wrap a LALInference CBC analysis
    state, and expose the likelihood and prior
    methods to python programs
    """
    def __init__(self, argv):
        """
        Parameters
        ----------
        argv : list
            List of command line arguments that will be used to
            set up the lalinference state. (similar to argv)
        """
        # Pack the arguments into a LAL string vector and parse them as a
        # command line. argv[0] must exist (IndexError otherwise).
        strvec = lal.CreateStringVector(argv[0])
        for a in argv[1:]:
            strvec = lal.AppendString2Vector(strvec, a)
        procParams = li.ParseStringVector(strvec)
        self.state = li.InitRunState(procParams)
        self.state.commandLine = procParams
        li.InitCBCThreads(self.state, 1)

        # This is what Nest does
        li.InjectInspiralSignal(self.state.data, self.state.commandLine)
        li.ApplyCalibrationErrors(self.state.data, procParams)
        if li.GetProcParamVal(procParams, '--roqtime_steps'):
            li.SetupROQdata(self.state.data, procParams)
        li.InitCBCPrior(self.state)
        li.InitLikelihood(self.state)
        # NOTE(review): InitCBCThreads is also called above, before the
        # injection. Both calls are kept to preserve the original
        # initialisation order. Confirm whether both are required.
        li.InitCBCThreads(self.state, 1)

    def log_likelihood(self, params):
        """
        Log-likelihood function from LALInference

        Parameters
        ----------
        params : dict
            Dict-like object of sampling parameters, will
            be automatically converted for lalinference

        Returns
        -------
        logL : float
            log-likelihood value
        """
        # Pick up non-sampled vars. If currentParams is a li.Variables struct
        # (presumably), LIVariablesWrap wraps it without copying. The update
        # below therefore writes the proposed values into it in place.
        liv = LIVariablesWrap(self.state.threads.currentParams)
        # Update with proposed values
        liv.update(params)
        self.state.threads.model.currentParams = liv.v
        # NOTE(review): always evaluates the phase-marginalised likelihood,
        # regardless of the likelihood InitLikelihood selected. Confirm intent.
        return li.MarginalisedPhaseLogLikelihood(liv.v, self.state.data, self.state.threads.model)

    def log_prior(self, params):
        """
        Log-prior function from LALInference

        Parameters
        ----------
        params : dict
            Dict-like object of sampling parameters, will
            be automatically converted for lalinference

        Returns
        -------
        logPr : float
            log-prior value
        """
        # Pick up non-sampled vars (updated in place, as in log_likelihood)
        liv = LIVariablesWrap(self.state.threads.currentParams)
        # Update with proposed values
        liv.update(params)
        return li.InspiralPrior(self.state, liv.v, self.state.threads.model)

    def params(self):
        """
        Parameter names from the LALInference model. Includes
        those which are fixed

        Returns
        -------
        names : list
            A list of parameter names
        """
        # Materialise a list (as documented) instead of returning a live
        # KeysView bound to a temporary wrapper.
        return list(LIVariablesWrap(self.state.threads.currentParams))

    def sampling_params(self):
        """
        Parameter names from the LALInference model. Includes
        only those which are varied in the sampling.

        Returns
        -------
        names : list
            A list of parameter names
        """
        pars = LIVariablesWrap(self.state.threads.currentParams)
        sampled = (li.LALINFERENCE_PARAM_LINEAR, li.LALINFERENCE_PARAM_CIRCULAR)
        # One varyType lookup per parameter
        return [p for p in pars if pars.varyType(p) in sampled]

    def prior_bounds(self):
        """
        Bounds of the sampling parameters.

        Returns
        -------
        bounds : dict
            A dict of (low,high) pairs, indexed by parameter name
            e.g. {'declination' : (0, 3.14159), ...}
            Sampling parameters with no '<name>_min'/'<name>_max' entries
            in the prior arguments are omitted.
        """
        bounds = {}
        libounds = LIVariablesWrap(self.state.priorArgs)
        for p in self.sampling_params():
            try:
                low = libounds[p + '_min']
                high = libounds[p + '_max']
            except KeyError:
                # No stored min/max for this parameter; skip it.
                continue
            bounds[p] = (low, high)
        return bounds
Class to wrap a LALInference CBC analysis state, and expose the likelihood and prior methods to python programs.
Definition: wrapper.py:117
def log_prior(self, params)
Log-prior function from LALInference.
Definition: wrapper.py:180
def prior_bounds(self)
Bounds of the sampling parameters.
Definition: wrapper.py:221
def params(self)
Parameter names from the LALInference model, including fixed ones.
Definition: wrapper.py:195
def sampling_params(self)
Parameter names from the LALInference model that are varied in the sampling.
Definition: wrapper.py:207
def log_likelihood(self, params)
Log-likelihood function from LALInference.
Definition: wrapper.py:158
def varyType(self, key)
Return the lalinference variable's varyType.
Definition: wrapper.py:72
def type(self, key)
Return the lalinference variable's data type.
Definition: wrapper.py:88
def __setitem__(self, key, value)
Definition: wrapper.py:30
def __init__(self, init=None)
Wrapper to present a LALInferenceVariable as a dict.
Definition: wrapper.py:16