cbcBayesCombinePosteriors.py
##python
# -*- coding: utf-8 -*-
# cbcBayesCombinePosteriors.py
#
# Copyright 2016
# Christopher Berry <christopher.berry@ligo.org>
# Sebastian Gaebel <sebastian.gaebel@ligo.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.

24"""
25This is used for combining posterior samples together.
26"""

#Import standard things
import argparse
import sys
try:
    from lalinference import LALInferenceHDF5PosteriorSamplesDatasetName as posterior_grp_name
except ImportError:
    #Fall back to the standard group name if lalinference is unavailable
    posterior_grp_name = "posterior_samples"
from collections import defaultdict
import h5py
import numpy as np

#Set-up commands
parser = argparse.ArgumentParser(description="Combine some posterior samples.")
parser.add_argument("-o", "--output", dest="outfilename", default="combined_posterior",
                    help="Write combined posterior to OUTFILE", metavar="OUTFILE")
parser.add_argument("-p", "--pos", action="append", dest="infilename",
                    help="Combine posteriors from INFILE", metavar="INFILE", required=True)
parser.add_argument("-a", "--all", dest="all",
                    action="store_true", default=False,
                    help="Use all posterior samples (do not weight)")
parser.add_argument("-w", "--weight", action="append", default=[], dest="weightings",
                    help="Weighting for posteriors", type=float)
shuffleGroup = parser.add_mutually_exclusive_group()
shuffleGroup.add_argument("-s", "--shuffle",
                          action="store_true", dest="shuffle", default=True,
                          help="Randomise posterior samples before combination [Default]")
shuffleGroup.add_argument("-ns", "--no-shuffle",
                          action="store_false", dest="shuffle",
                          help="Do not randomise posterior samples before combination")
parser.add_argument("-m", "--mix", dest="mix",
                    action="store_true", default=False,
                    help="Randomise combined samples")
fileGroup = parser.add_mutually_exclusive_group()
fileGroup.add_argument("-t", "--text",
                       action="store_false", dest="hdf", default=False,
                       help="Use ASCII posterior (.dat) files [Default]")
fileGroup.add_argument("-f", "--hdf",
                       action="store_true", dest="hdf",
                       help="Use HDF5 posterior files")


args = parser.parse_args()

#Count arguments
nPos = np.size(args.infilename)
nWeight = np.size(args.weightings)

print("Weights:", args.weightings)

#Check sensible combination of arguments
if (nWeight != 0):
    if args.all:
        print("You cannot use all posterior samples and weight them!")
        sys.exit(1)

    if (nWeight != nPos):
        print("Please either specify a weight for each posterior file or none")
        sys.exit(1)
else:
    args.weightings = [1.0] * nPos

#Specify combination ID
combineID = "combined"
if args.all:
    combineID = combineID+"_all"
else:
    combineID = combineID+"_weight_"+'_'.join(map(str, args.weightings))

if args.shuffle:
    combineID = combineID+"_shuffle"
else:
    combineID = combineID+"_noshuffle"

if args.mix:
    combineID = combineID+"_mixed"

print("Combined ID:", combineID)
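#For example, "-w 1 -w 3" with default shuffling gives
#combineID = "combined_weight_1.0_3.0_shuffle"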

#Initialise lists to hold data
samples = []
paramsList = []
sizeList = []
metadata = {"lalinference": defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID: defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID+"/"+posterior_grp_name: defaultdict(lambda: [None]*nPos)}

#Read in data
for posIndex in range(nPos):
    if args.hdf:
        #HDF5 files with metadata
        with h5py.File(args.infilename[posIndex], "r") as inFile:

            group = inFile["lalinference"]
            for key in group.attrs:
                metadata["lalinference"][key][posIndex] = group.attrs[key]

            run_id = list(group.keys())[0]

            group = group[run_id]

            for key in group.attrs:
                metadata["lalinference/"+combineID][key][posIndex] = group.attrs[key]

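            #Record provenance: each round of combining stores the input run
            #IDs under a new attribute name (combined, recombined, ...)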
130 if "combined_run_ids" not in group.attrs:
131 metadata["lalinference/"+combineID]["combined_run_ids"][posIndex] = run_id
132 elif "recombined_run_ids" not in group.attrs:
133 metadata["lalinference/"+combineID]["recombined_run_ids"][posIndex] = run_id
134 elif "rerecombined_run_ids" not in group.attrs:
135 metadata["lalinference/"+combineID]["rerecombined_run_ids"][posIndex] = run_id
136 elif "rererecombined_run_ids" not in group.attrs:
137 metadata["lalinference/"+combineID]["rererecombined_run_ids"][posIndex] = True
138 else:
139 print("Too many combinations to count!")


            group = group[posterior_grp_name]
            for key in group.attrs:
                metadata["lalinference/"+combineID+"/"+posterior_grp_name][key][posIndex] = group.attrs[key]

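            #Assemble a structured dtype from the HDF5 datasets; this assumes
            #every dataset in the posterior group has the same length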
            posDtype = []
            for key in group:
                posDtype.append((key, group[key].dtype))
                shape = group[key].shape

            posData = np.empty(shape, dtype=posDtype)

            for key in group:
                posData[key] = group[key][:]

    else:
        #Standard text file
        posData = np.genfromtxt(args.infilename[posIndex], names=True)

    if (args.shuffle):
        np.random.shuffle(posData)

    samples.append(posData)
    paramsList.append(set(posData.dtype.names))
    sizeList.append(np.size(posData))


#Keep only the parameters common to every input file
paramsOut = list(set.intersection(*paramsList))

datatypes = samples[0][paramsOut].dtype

#Combine posteriors
if (args.all):
    #Use all samples
    sizeOut = sum(sizeList)
    samplesOut = np.empty(sizeOut, dtype=datatypes)

    indexSize = sizeList

else:
    #Weight different posteriors
    fracWeight = np.asarray(args.weightings) / float(sum(args.weightings))

    testNum = fracWeight * float(sum(sizeList))
    minIndex = np.argmin(np.asarray(sizeList) / np.asarray(testNum))

    testSize = sizeList[minIndex] / fracWeight[minIndex]

    weightNum = np.around(fracWeight * testSize).astype(int)
    sizeOut = sum(weightNum)
    samplesOut = np.empty(sizeOut, dtype=datatypes)

    indexSize = weightNum

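#Worked example (illustrative): sizes [1000, 3000] with weights [1.0, 1.0]
#give fracWeight = [0.5, 0.5]; the first file is limiting, so
#testSize = 1000/0.5 = 2000 and weightNum = [1000, 1000].
#The smallest posterior, relative to its weight, sets the output size.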
print("Using number of samples", indexSize)

#Copy the first indexSize[i] samples from each input into the output array
startIndex = 0
for posIndex in range(nPos):
    stopIndex = startIndex + indexSize[posIndex]

    for paramItem in paramsOut:
        samplesOut[paramItem][startIndex:stopIndex] = samples[posIndex][paramItem][0:indexSize[posIndex]]

    startIndex = stopIndex


#Mix samples
if args.mix:
    np.random.shuffle(samplesOut)


#Save output
if args.hdf:
    #HDF5 file with metadata
    with h5py.File(args.outfilename, "w") as outFile:
        group = outFile.create_group("lalinference")
        group = group.create_group(combineID)
        group = group.create_group(posterior_grp_name)
        for key in samplesOut.dtype.names:
            group.create_dataset(key, data=samplesOut[key], shuffle=True, compression="gzip")

        for level in metadata:
            for key in metadata[level]:
                outFile[level].attrs[key] = metadata[level][key]

else:
    #Standard text output
    paramHeader = "\t".join(paramsOut)
    np.savetxt(args.outfilename, samplesOut.T, delimiter="\t", header=paramHeader, comments="")


#Done!