Loading [MathJax]/extensions/TeX/AMSsymbols.js
LALInference 4.1.9.1-8a6b96f
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Modules Pages
LALInferenceGenerateROQTest.c
Go to the documentation of this file.
1#include <lal/LALInferenceGenerateROQ.h>
2#include <lal/LALConstants.h>
3#include <lal/XLALGSL.h>
4#include <gsl/gsl_randist.h>
5
6#include <time.h>
7#include <math.h>
8
9/* check whether to include omp.h for use of multiple cores */
10#ifdef HAVE_OPENMP
11#include <omp.h>
12#endif
13
/* test waveform length (number of frequency samples per waveform) */
#define WL 1024

/* number of training set waveforms */
#define TSSIZE 1000

/* tolerance for the reduced basis generation loop.
 * Written as 1e-11: this is exactly the value of the previous literal 10e-12,
 * which was easy to misread as 1e-12. */
#define TOLERANCE 1e-11

/* maximum allowed fractional log likelihood difference, in percent */
#define LTOL 0.1
24
25/* simple inspiral phase model */
26double calc_phase(double frequency, double Mchirp);
27
28/* model for a purely real frequency domain inspiral-like signal */
29double real_model(double frequency, double Mchirp, double modperiod);
30
31/* model for a complex frequency domain inspiral-like signal */
32COMPLEX16 imag_model(double frequency, double Mchirp, double modperiod);
33
34double calc_phase(double frequency, double Mchirp){
35 return (-0.25*LAL_PI + ( 3./( 128. * pow(Mchirp*LAL_MTSUN_SI*LAL_PI*frequency, 5./3.) ) ) );
36}
37
38double real_model(double frequency, double Mchirp, double modperiod){
39 return ( pow(frequency, -7./6.) * pow(Mchirp*LAL_MTSUN_SI,5./6.) * cos(calc_phase(frequency,Mchirp)) )*sin(LAL_TWOPI*frequency/modperiod);
40}
41
42COMPLEX16 imag_model(double frequency, double Mchirp, double modperiod){
43 return ( pow(frequency, -7./6.) * pow(Mchirp*LAL_MTSUN_SI,5./6.) * cexp(I*calc_phase(frequency,Mchirp)) )*sin(LAL_TWOPI*frequency/modperiod);
44}
45
/*
 * Test driver for the ROQ generation routines.
 *
 * Builds real and complex training sets from the toy chirp models
 * real_model()/imag_model() (plus their quadratic terms), reduces them to
 * orthonormal bases and ROQ interpolants, and then compares the full and
 * ROQ-based likelihood terms <h|h> and <d|h> (with wall-clock timings) for one
 * test model against Gaussian random data.
 *
 * Returns 1 if the fractional log likelihood difference (in percent) for either
 * the real or the complex model exceeds LTOL, otherwise 0.
 *
 * NOTE(review): this listing appears to be missing lines — e.g. the
 * declarations of TSdims, freqs, cinterp, vars, data, cdata and dmw, and the
 * creation of interp/cinterp, are not visible here. Confirm against the full file.
 */
46int main(void) {
47 REAL8Array *TS = NULL, *TSquad = NULL, *cTSquad = NULL; /* real training sets: real waveforms, their squares, and |h|^2 of the complex waveforms */
48 COMPLEX16Array *cTS = NULL; /* the training set of complex waveforms */
49 UINT4Vector *gdpts = NULL; /* the greedy points used for the reduced basis generation */
50
51 size_t TSsize; /* the size of the training set (number of waveforms) */
52 size_t wl; /* the length of each waveform */
53 size_t k = 0, j = 0, i = 0;
54
55 REAL8Array *RBlinear = NULL, *RBquad = NULL, *cRBquad = NULL; /* the real reduced basis sets (linear; quadratic of the real model; |h|^2 of the complex model) */
56 COMPLEX16Array *cRBlinear = NULL; /* the complex reduced basis set */
57
58 LALInferenceREALROQInterpolant *interp = NULL, *interpQuad = NULL, *cinterpQuad = NULL; /* real interpolants: linear, quadratic (real model), quadratic (complex model) */
60
61 double tolerance = TOLERANCE; /* tolerance for reduced basis generation loop */
62
63 TSsize = TSSIZE;
64 wl = WL;
65
66 /* allocate memory for training set (TSsize waveforms x wl frequency samples) */
68 TSdims->data[0] = TSsize;
69 TSdims->data[1] = wl;
70 TS = XLALCreateREAL8Array( TSdims );
71 TSquad = XLALCreateREAL8Array( TSdims );
72 cTS = XLALCreateCOMPLEX16Array( TSdims );
73 cTSquad = XLALCreateREAL8Array( TSdims );
74
/* GSL matrix views over the LAL array storage, so rows can be filled with gsl_matrix_set */
75 gsl_matrix_view TSview, TSviewquad, cTSviewquad;
76 TSview = gsl_matrix_view_array( TS->data, TSdims->data[0], TSdims->data[1] );
77 TSviewquad = gsl_matrix_view_array( TSquad->data, TSdims->data[0], TSdims->data[1] );
78 cTSviewquad = gsl_matrix_view_array( cTSquad->data, TSdims->data[0], TSdims->data[1] );
79 gsl_matrix_complex_view cTSview;
80 cTSview = gsl_matrix_complex_view_array( (double *)cTS->data, TSdims->data[0], TSdims->data[1] );
81
82 XLALDestroyUINT4Vector( TSdims );
83
84 /* the waveform model is just a simple chirp so set up chirp mass range for training set */
85 double fmin0 = 48, fmax0 = 256, f0 = 0., m0 = 0.;
86 double df = (fmax0-fmin0)/(wl-1.); /* model frequency step (spacing of the wl samples between fmin0 and fmax0) */
87 double Mcmax = 2., Mcmin = 1.5, Mc = 0.;
88 double periodmax = 1./99.995, periodmin = 1./100., modperiod = 0.; /* range for the randomly drawn modulation period */
89
/* single weight (df) passed as the 'delta' argument of the basis generators;
 * NOTE(review): presumably a uniform frequency-step weight for the inner products — confirm */
90 REAL8Vector *fweights = XLALCreateREAL8Vector( 1 );
91 fweights->data[0] = df;
92
94
95 /* random number generator setup */
96 const gsl_rng_type *T;
97 gsl_rng *r;
98 gsl_rng_env_setup();
99 T = gsl_rng_default;
100 r = gsl_rng_alloc(T);
101
102 /* set up training sets (real and complex models, plus their quadratic terms);
 * chirp masses are spaced uniformly in Mc^(5/3) between Mcmin and Mcmax and the
 * modulation period is drawn uniformly at random for each waveform */
103 for ( k=0; k < TSsize; k++ ){
104 Mc = pow(pow(Mcmin, 5./3.) + (double)k*(pow(Mcmax, 5./3.)-pow(Mcmin, 5./3.))/((double)TSsize-1), 3./5.);
105 modperiod = gsl_ran_flat(r, periodmin, periodmax);
106
107 for ( j=0; j < wl; j++ ){
108 f0 = fmin0 + (double)j*(fmax0-fmin0)/((double)wl-1.);
109
110 gsl_complex gctmp;
111 COMPLEX16 ctmp;
112 m0 = real_model(f0, Mc, modperiod);
113 ctmp = imag_model(f0, Mc, modperiod);
114 GSL_SET_COMPLEX(&gctmp, creal(ctmp), cimag(ctmp));
115 freqs->data[j] = f0;
116 gsl_matrix_set(&TSview.matrix, k, j, m0);
117 gsl_matrix_set(&TSviewquad.matrix, k, j, m0*m0);
118 gsl_matrix_complex_set(&cTSview.matrix, k, j, gctmp);
119 gsl_matrix_set(&cTSviewquad.matrix, k, j, creal(ctmp*conj(ctmp))); /* |h|^2 */
120 }
121 }
122
123 /* create reduced orthonormal bases from the training sets (linear and quadratic, real and complex models) */
124 REAL8 maxprojerr = 0.;
/* NOTE(review): dimLength->data[] holds UINT4 values but the messages below print
 * them with %d; %u would match the type */
125 maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBlinear, fweights, tolerance, &TS, &gdpts);
126 XLALDestroyUINT4Vector( gdpts );
127 fprintf(stderr, "No. linear nodes (real) = %d, %d x %d; Maximum projection err. = %le\n", RBlinear->dimLength->data[0], RBlinear->dimLength->data[0], RBlinear->dimLength->data[1], maxprojerr);
128 maxprojerr = LALInferenceGenerateCOMPLEX16OrthonormalBasis(&cRBlinear, fweights, tolerance, &cTS, &gdpts);
129 XLALDestroyUINT4Vector( gdpts );
130 fprintf(stderr, "No. linear nodes (complex) = %d, %d x %d; Maximum projection err. = %le\n", cRBlinear->dimLength->data[0], cRBlinear->dimLength->data[0], cRBlinear->dimLength->data[1], maxprojerr);
131 maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&RBquad, fweights, tolerance, &TSquad, &gdpts);
132 XLALDestroyUINT4Vector( gdpts );
133 fprintf(stderr, "No. quadratic nodes (real) = %d, %d x %d; Maximum projection err. = %le\n", RBquad->dimLength->data[0], RBquad->dimLength->data[0], RBquad->dimLength->data[1], maxprojerr);
134 maxprojerr = LALInferenceGenerateREAL8OrthonormalBasis(&cRBquad, fweights, tolerance, &cTSquad, &gdpts);
135 XLALDestroyUINT4Vector( gdpts );
136 fprintf(stderr, "No. quadratic nodes (complex) = %d, %d x %d; Maximum projection err. = %le\n", cRBquad->dimLength->data[0], cRBquad->dimLength->data[0], cRBquad->dimLength->data[1], maxprojerr);
137
138 /* free the training sets and the frequency weights */
141 XLALDestroyREAL8Array( TSquad );
142 XLALDestroyREAL8Array( cTSquad );
143 XLALDestroyREAL8Vector( fweights );
144
145 /* get the linear interpolant */
148
149 /* get the quadratic interpolant */
150 interpQuad = LALInferenceGenerateREALROQInterpolant(RBquad);
151 cinterpQuad = LALInferenceGenerateREALROQInterpolant(cRBquad);
152
153 /* free the reduced basis */
154 XLALDestroyREAL8Array(RBlinear);
155 XLALDestroyCOMPLEX16Array(cRBlinear);
156 XLALDestroyREAL8Array(RBquad);
157 XLALDestroyREAL8Array(cRBquad);
158
159 /* now get the terms for the likelihood with and without the reduced order quadrature
160 * and do some timing tests */
161
162 /* create the model dot model weights */
/* single variance value of 1; NOTE(review): presumably the noise variance — confirm */
164 vars->data[0] = 1.;
165
166 REAL8Vector *mmw = LALInferenceGenerateQuadraticWeights(interpQuad->B, vars);
167 REAL8Vector *cmmw = LALInferenceGenerateQuadraticWeights(cinterpQuad->B, vars);
168
169 /* create zero-mean, unit-sigma Gaussian random data (real and complex) */
172 for ( i=0; i<wl; i++ ){
173 data->data[i] = gsl_ran_gaussian(r, 1.0); /* real data */
174 cdata->data[i] = gsl_ran_gaussian(r, 1.0) + I*gsl_ran_gaussian(r, 1.0); /* complex data */
175 }
176
177 /* create the data dot model weights */
179 COMPLEX16Vector *cdmw = LALInferenceGenerateCOMPLEX16LinearWeights(cinterp->B, cdata, vars);
180
182
183 /* pick a chirp mass and generate a model to compare likelihoods */
184 double randMc = 1.873; /* chirp mass of the test model (not a frequency) */
185 double randmp = 1./99.9989; /* modulation period of the test model */
186
/* full-length models (all wl frequencies) and reduced models (interpolant nodes only) */
187 gsl_vector *modelfull = gsl_vector_alloc(wl);
188 REAL8Vector *modelreduced = XLALCreateREAL8Vector( interp->B->dimLength->data[0] );
189 REAL8Vector *modelreducedquad = XLALCreateREAL8Vector( interpQuad->B->dimLength->data[0] );
190 gsl_vector_complex *cmodelfull = gsl_vector_complex_alloc(wl);
191 COMPLEX16Vector *cmodelreduced = XLALCreateCOMPLEX16Vector( cinterp->B->dimLength->data[0] );
192 REAL8Vector *cmodelreducedquad = XLALCreateREAL8Vector( cinterpQuad->B->dimLength->data[0] );
193
194 /* create models */
195 for ( i=0; i<wl; i++ ){
196 /* models at all frequencies */
197 gsl_vector_set(modelfull, i, real_model(freqs->data[i], randMc, randmp));
198
199 COMPLEX16 cval = imag_model(freqs->data[i], randMc, randmp);
200 gsl_complex gcval;
201 GSL_SET_COMPLEX(&gcval, creal(cval), cimag(cval));
202 gsl_vector_complex_set(cmodelfull, i, gcval);
203 }
204
205 /* models at interpolant nodes */
206 for ( i=0; i<modelreduced->length; i++ ){ /* real model */
207 REAL8 rm = real_model(freqs->data[interp->nodes[i]], randMc, randmp);
208 modelreduced->data[i] = rm;
209 }
210 for ( i=0; i<modelreducedquad->length; i++ ){ /* real model */
211 REAL8 rm = real_model(freqs->data[interpQuad->nodes[i]], randMc, randmp);
212 modelreducedquad->data[i] = rm*rm;
213 }
214 for ( i=0; i<cmodelreduced->length; i++ ){ /* complex model */
215 COMPLEX16 crm = imag_model(freqs->data[cinterp->nodes[i]], randMc, randmp);
216 cmodelreduced->data[i] = crm;
217 }
218 for ( i=0; i<cmodelreducedquad->length; i++ ){ /* complex model */
219 COMPLEX16 crm = imag_model(freqs->data[cinterpQuad->nodes[i]], randMc, randmp);
220 cmodelreducedquad->data[i] = creal(crm*conj(crm)); /* |h|^2 at the quadratic nodes */
221 }
222
223 XLALDestroyREAL8Vector( freqs );
224
225 /* timing variables */
/* NOTE(review): struct timeval/gettimeofday are declared in <sys/time.h>, which is
 * not included directly by this file — presumably pulled in transitively; confirm */
226 struct timeval t1, t2, t3, t4;
227 double dt1, dt2;
228
229 /* start with the real model */
230 /* get the model model term with the full model */
231 REAL8 mmfull, mmred;
232 gettimeofday(&t1, NULL);
233 XLAL_CALLGSL( gsl_blas_ddot(modelfull, modelfull, &mmfull) ); /* real model */
234 gettimeofday(&t2, NULL);
235
236 /* now get it with the reduced order quadrature */
237 gettimeofday(&t3, NULL);
238 mmred = LALInferenceROQREAL8DotProduct(mmw, modelreducedquad);
239 gettimeofday(&t4, NULL);
240
/* elapsed wall-clock times in seconds (microsecond resolution); if dt2 is 0 the
 * printed time ratio is inf or nan */
241 dt1 = (double)((t2.tv_sec + t2.tv_usec*1.e-6) - (t1.tv_sec + t1.tv_usec*1.e-6));
242 dt2 = (double)((t4.tv_sec + t4.tv_usec*1.e-6) - (t3.tv_sec + t3.tv_usec*1.e-6));
243 fprintf(stderr, "Real Signal:\n - M dot M (full) = %le [%.9lf s], M dot M (reduced) = %le [%.9lf s], time ratio = %lf\n", mmfull, dt1, mmred, dt2, dt1/dt2);
244
245 /* get the data model term with the full model */
246 REAL8 dmfull, dmred;
247 gsl_vector_view dataview = gsl_vector_view_array(data->data, wl);
248 gettimeofday(&t1, NULL);
249 XLAL_CALLGSL( gsl_blas_ddot(&dataview.vector, modelfull, &dmfull) );
250 gettimeofday(&t2, NULL);
251
252 /* now get it with the reduced order quadrature */
253 gettimeofday(&t3, NULL);
254 dmred = LALInferenceROQREAL8DotProduct(dmw, modelreduced);
255 gettimeofday(&t4, NULL);
256
257 dt1 = (double)((t2.tv_sec + t2.tv_usec*1.e-6) - (t1.tv_sec + t1.tv_usec*1.e-6));
258 dt2 = (double)((t4.tv_sec + t4.tv_usec*1.e-6) - (t3.tv_sec + t3.tv_usec*1.e-6));
259 fprintf(stderr, " - D dot M (full) = %le [%.9lf s], D dot M (reduced) = %le [%.9lf s], time ratio = %lf\n", dmfull, dt1, dmred, dt2, dt1/dt2);
260
261 /* check difference in log likelihoods, using the model-dependent part of
 * <d-h|d-h>, i.e. <h|h> - 2<d|h> (the <d|d> term is the same for both) */
262 double Lfull, Lred, Lfrac;
263
264 Lfull = mmfull - 2.*dmfull;
265 Lred = mmred - 2.*dmred;
266 Lfrac = 100.*fabs(Lfull-Lred)/fabs(Lfull); /* fractional log likelihood difference (in %) */
267
268 fprintf(stderr, " - Fractional difference in log likelihoods = %lf%%\n", Lfrac);
269
271 gsl_vector_free(modelfull);
272 XLALDestroyREAL8Vector(modelreduced);
273 XLALDestroyREAL8Vector(modelreducedquad);
276
277 /* check log likelihood difference is within tolerance */
/* NOTE(review): this early return skips the complex-model test and the cleanup
 * below, so those allocations are not released on the failure path */
278 if ( Lfrac > LTOL ) { return 1; }
279
280 /* now do the same with the complex model */
281 /* get the model model term with the full model */
282 REAL8 cmmred, cmmfull;
283 gsl_complex cmmfulltmp;
284 gettimeofday(&t1, NULL);
285 XLAL_CALLGSL( gsl_blas_zdotc(cmodelfull, cmodelfull, &cmmfulltmp) ); /* complex model */
286 cmmfull = GSL_REAL(cmmfulltmp);
287 gettimeofday(&t2, NULL);
288
289 gettimeofday(&t3, NULL);
290 cmmred = LALInferenceROQREAL8DotProduct(cmmw, cmodelreducedquad);
291 gettimeofday(&t4, NULL);
292 dt1 = (double)((t2.tv_sec + t2.tv_usec*1.e-6) - (t1.tv_sec + t1.tv_usec*1.e-6));
293 dt2 = (double)((t4.tv_sec + t4.tv_usec*1.e-6) - (t3.tv_sec + t3.tv_usec*1.e-6));
294 fprintf(stderr, "Complex Signal:\n - M dot M (full) = %le [%.9lf s], M dot M (reduced) = %le [%.9lf s], time ratio = %lf\n", cmmfull, dt1, cmmred, dt2, dt1/dt2);
295
/* gsl_blas_zdotc conjugates its first argument, so cdmfull = sum_i conj(d_i) h_i */
296 COMPLEX16 cdmfull, cdmred;
297 gsl_complex cdmfulltmp;
298 gsl_vector_complex_view cdataview = gsl_vector_complex_view_array((double*)cdata->data, wl);
299 gettimeofday(&t1, NULL);
300 XLAL_CALLGSL( gsl_blas_zdotc(&cdataview.vector, cmodelfull, &cdmfulltmp) );
301 cdmfull = GSL_REAL(cdmfulltmp) + I*GSL_IMAG(cdmfulltmp);
302 gettimeofday(&t2, NULL);
303
304 gettimeofday(&t3, NULL);
305 cdmred = LALInferenceROQCOMPLEX16DotProduct(cdmw, cmodelreduced);
306 gettimeofday(&t4, NULL);
307
308 dt1 = (double)((t2.tv_sec + t2.tv_usec*1.e-6) - (t1.tv_sec + t1.tv_usec*1.e-6));
309 dt2 = (double)((t4.tv_sec + t4.tv_usec*1.e-6) - (t3.tv_sec + t3.tv_usec*1.e-6));
310 fprintf(stderr, " - D dot M (full) = %le [%.9lf s], D dot M (reduced) = %le [%.9lf s], time ratio = %lf\n", creal(cdmfull), dt1, creal(cdmred), dt2, dt1/dt2);
311
312 /* check difference in log likelihoods (only the real part of <d|h> enters) */
313 Lfull = cmmfull - 2.*creal(cdmfull);
314 Lred = cmmred - 2.*creal(cdmred);
315 Lfrac = 100.*fabs(Lfull-Lred)/fabs(Lfull); /* fractional log likelihood difference (in %) */
316
317 fprintf(stderr, " - Fractional difference in log likelihoods = %lf%%\n", Lfrac);
318
320 gsl_vector_complex_free(cmodelfull);
321 XLALDestroyCOMPLEX16Vector(cmodelreduced);
322 XLALDestroyREAL8Vector(cmodelreducedquad);
/* NOTE(review): no release of mmw, cmmw, cdmw, the interpolants or the rng r is
 * visible in this listing — confirm they are freed in the lines not shown */
325
330
331 /* check log likelihood difference is within tolerance */
332 if ( Lfrac > LTOL ) { return 1; }
333
335
336 return 0;
337}
COMPLEX16 LALInferenceROQCOMPLEX16DotProduct(COMPLEX16Vector *weights, COMPLEX16Vector *model)
Calculate the dot product of two complex vectors using the ROQ interpolant.
REAL8 LALInferenceGenerateCOMPLEX16OrthonormalBasis(COMPLEX16Array **RBin, const REAL8Vector *delta, REAL8 tolerance, COMPLEX16Array **TS, UINT4Vector **greedypoints)
Create an orthonormal basis set from a training set of complex waveforms.
void LALInferenceRemoveCOMPLEXROQInterpolant(LALInferenceCOMPLEXROQInterpolant *a)
Free memory for a LALInferenceCOMPLEXROQInterpolant.
COMPLEX16Vector * LALInferenceGenerateCOMPLEX16LinearWeights(COMPLEX16Array *B, COMPLEX16Vector *data, REAL8Vector *vars)
Create the weights for the ROQ interpolant for the complex data and model dot product <d|h>
REAL8Vector * LALInferenceGenerateREAL8LinearWeights(REAL8Array *B, REAL8Vector *data, REAL8Vector *vars)
Create the weights for the ROQ interpolant for the linear data and model dot product <d|h>
REAL8 LALInferenceROQREAL8DotProduct(REAL8Vector *weights, REAL8Vector *model)
Calculate the dot product of two vectors using the ROQ interpolant.
REAL8Vector * LALInferenceGenerateQuadraticWeights(REAL8Array *B, REAL8Vector *vars)
Create the weights for the ROQ interpolant for the model quadratic model term real(<h|h>).
LALInferenceCOMPLEXROQInterpolant * LALInferenceGenerateCOMPLEXROQInterpolant(COMPLEX16Array *RB)
Create a complex empirical interpolant from a set of orthonormal basis functions.
REAL8 LALInferenceGenerateREAL8OrthonormalBasis(REAL8Array **RBin, const REAL8Vector *delta, REAL8 tolerance, REAL8Array **TS, UINT4Vector **greedypoints)
Create an orthonormal basis set from a training set of real waveforms.
void LALInferenceRemoveREALROQInterpolant(LALInferenceREALROQInterpolant *a)
Free memory for a LALInferenceREALROQInterpolant.
LALInferenceREALROQInterpolant * LALInferenceGenerateREALROQInterpolant(REAL8Array *RB)
Create a real empirical interpolant from a set of orthonormal basis functions.
double real_model(double frequency, double Mchirp, double modperiod)
#define TOLERANCE
#define TSSIZE
COMPLEX16 imag_model(double frequency, double Mchirp, double modperiod)
int main(void)
#define LTOL
double calc_phase(double frequency, double Mchirp)
int j
int k
void LALCheckMemoryLeaks(void)
#define fprintf
#define XLAL_CALLGSL(statement)
sigmaKerr data[0]
void XLALDestroyREAL8Array(REAL8Array *)
COMPLEX16Array * XLALCreateCOMPLEX16Array(UINT4Vector *)
void XLALDestroyCOMPLEX16Array(COMPLEX16Array *)
REAL8Array * XLALCreateREAL8Array(UINT4Vector *)
#define LAL_PI
#define LAL_TWOPI
#define LAL_MTSUN_SI
double complex COMPLEX16
double REAL8
static const INT4 r
void XLALDestroyUINT4Vector(UINT4Vector *vector)
REAL8Vector * XLALCreateREAL8Vector(UINT4 length)
COMPLEX16Vector * XLALCreateCOMPLEX16Vector(UINT4 length)
void XLALDestroyREAL8Vector(REAL8Vector *vector)
void XLALDestroyCOMPLEX16Vector(COMPLEX16Vector *vector)
UINT4Vector * XLALCreateUINT4Vector(UINT4 length)
UINT4Vector * dimLength
COMPLEX16 * data
COMPLEX16 * data
A structure to hold a complex (double precision) interpolant matrix and interpolation node indices.
COMPLEX16Array * B
The interpolant matrix.
UINT4 * nodes
The nodes (indices) for the interpolation.
A structure to hold a real (double precision) interpolant matrix and interpolation node indices.
UINT4 * nodes
The nodes (indices) for the interpolation.
REAL8Array * B
The interpolant matrix.
UINT4Vector * dimLength
REAL8 * data
REAL8 * data
UINT4 * data
double df