LALInference  4.1.6.1-89842e6
cbcBayesCombinePosteriors.py
# -*- coding: utf-8 -*-
# cbcBayesCombinePosteriors.py
#
# Copyright 2016
# Christopher Berry <christopher.berry@ligo.org>
# Sebastian Gaebel <sebastian.gaebel@ligo.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.

23 """
24 This is used for combining posterior samples together.
25 """

#Import standard things
import argparse
try:
    from lalinference import LALInferenceHDF5PosteriorSamplesDatasetName as posterior_grp_name
except ImportError:
    #Fall back to the default group name if LALInference is not importable
    posterior_grp_name = "posterior_samples"
from collections import defaultdict
import h5py
import numpy as np

#Set-up commands
parser = argparse.ArgumentParser(description="Combine some posterior samples.")
parser.add_argument("-o", "--output", dest="outfilename", default="combined_posterior",
                    help="Write combined posterior to OUTFILE", metavar="OUTFILE")
parser.add_argument("-p", "--pos", action="append", dest="infilename",
                    help="Combine posteriors from INFILE", metavar="INFILE", required=True)
parser.add_argument("-a", "--all", dest="all",
                    action="store_true", default=False,
                    help="Use all posterior samples (do not weight)")
parser.add_argument("-w", "--weight", action="append", default=[], dest="weightings",
                    help="Weighting for posteriors", type=float)
shuffleGroup = parser.add_mutually_exclusive_group()
shuffleGroup.add_argument("-s", "--shuffle",
                          action="store_true", dest="shuffle", default=True,
                          help="Randomise posterior samples before combination [Default]")
shuffleGroup.add_argument("-ns", "--no-shuffle",
                          action="store_false", dest="shuffle",
                          help="Do not randomise posterior samples before combination")
parser.add_argument("-m", "--mix", dest="mix",
                    action="store_true", default=False,
                    help="Randomise combined samples")
fileGroup = parser.add_mutually_exclusive_group()
fileGroup.add_argument("-t", "--text",
                       action="store_false", dest="hdf", default=False,
                       help="Use ASCII posterior (.dat) files [Default]")
fileGroup.add_argument("-f", "--hdf",
                       action="store_true", dest="hdf",
                       help="Use HDF5 posterior files")


args = parser.parse_args()

#Count arguments
nPos = np.size(args.infilename)
nWeight = np.size(args.weightings)

print(args.weightings)

#Check for a sensible combination of arguments
if (nWeight != 0):
    if args.all:
        print("You cannot use all posterior samples and weight them!")
        exit(1)

    if (nWeight != nPos):
        print("Please specify either one weight per posterior file or no weights at all")
        exit(1)
else:
    #No weights given: weight all input posteriors equally
    args.weightings = [1.0] * nPos

#Specify combination ID
combineID = "combined"
if args.all:
    combineID = combineID+"_all"
else:
    combineID = combineID+"_weight_"+'_'.join(map(str, args.weightings))

if args.shuffle:
    combineID = combineID+"_shuffle"
else:
    combineID = combineID+"_noshuffle"

if args.mix:
    combineID = combineID+"_mixed"

print("Combined ID:", combineID)
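#For example, combining two equally weighted files with the default options
#gives combineID = "combined_weight_1.0_1.0_shuffle"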

#Initiate lists to hold data
samples = []
paramsList = []
sizeList = []
metadata = {"lalinference": defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID: defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID+"/"+posterior_grp_name: defaultdict(lambda: [None]*nPos)}
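#metadata maps each output HDF5 group path to a dictionary of
#attribute name -> per-file list of values; the defaultdict keeps a slot
#of None for every input file until that attribute is actually read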

#Read in data
for posIndex in range(nPos):
    if args.hdf:
        #HDF5 files with metadata
        with h5py.File(args.infilename[posIndex], "r") as inFile:

            group = inFile["lalinference"]
            for key in group.attrs:
                metadata["lalinference"][key][posIndex] = group.attrs[key]

            run_id = list(group.keys())[0]

            group = group[run_id]

            for key in group.attrs:
                metadata["lalinference/"+combineID][key][posIndex] = group.attrs[key]

            #Record which run IDs went into this combination; files that have
            #already been combined gain another "re" prefix each time round
            if "combined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["combined_run_ids"][posIndex] = run_id
            elif "recombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["recombined_run_ids"][posIndex] = run_id
            elif "rerecombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["rerecombined_run_ids"][posIndex] = run_id
            elif "rererecombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["rererecombined_run_ids"][posIndex] = run_id
            else:
                print("Too many combinations to count!")

            group = group[posterior_grp_name]
            for key in group.attrs:
                metadata["lalinference/"+combineID+"/"+posterior_grp_name][key][posIndex] = group.attrs[key]

            #Build a structured dtype from the per-parameter datasets
            #(all datasets are assumed to share the same length)
            posDtype = []
            for key in group:
                posDtype.append((key, group[key].dtype))
                shape = group[key].shape

            posData = np.empty(shape, dtype=posDtype)

            for key in group:
                posData[key] = group[key][:]

    else:
        #Standard text file
        posData = np.genfromtxt(args.infilename[posIndex], names=True)

    if (args.shuffle):
        np.random.shuffle(posData)

    samples.append(posData)
    paramsList.append(set(posData.dtype.names))
    sizeList.append(np.size(posData))
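
#samples, paramsList and sizeList are parallel lists: for each input file
#they hold the samples, the set of parameter names, and the sample count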


#Create intersection: keep only the parameters common to all input posteriors
paramsOut = list(set.intersection(*paramsList))

datatypes = samples[0][paramsOut].dtype
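#For example, if one file has {m1, m2, distance} and another only {m1, m2},
#paramsOut will be [m1, m2] (the parameter names here are illustrative)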

#Combine posteriors
if (args.all):
    #Use all samples
    sizeOut = sum(sizeList)
    samplesOut = np.empty(sizeOut, dtype=datatypes)

    indexSize = sizeList

else:
    #Weight the different posteriors
    fracWeight = np.asarray(args.weightings) / float(sum(args.weightings))

    #Find which input file runs out of samples first, given its target fraction
    testNum = fracWeight * float(sum(sizeList))
    minIndex = np.argmin(np.asarray(sizeList) / np.asarray(testNum))

    #Scale the total output size so that the limiting file is used in full
    testSize = sizeList[minIndex] / fracWeight[minIndex]

    weightNum = np.around(fracWeight * testSize).astype(int)
    sizeOut = sum(weightNum)
    samplesOut = np.empty(sizeOut, dtype=datatypes)

    indexSize = weightNum

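#Worked example: with sizeList = [3000, 1000] and weights [1.0, 1.0],
#fracWeight = [0.5, 0.5]; the second file is the limiting one, so
#testSize = 1000 / 0.5 = 2000 and weightNum = [1000, 1000]: all 1000
#samples of the smaller file are used along with 1000 of the larger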

print("Using numbers of samples:", indexSize)


#Copy the leading indexSize[posIndex] samples from each input into the output
startIndex = 0
for posIndex in range(0, nPos):
    stopIndex = startIndex + indexSize[posIndex]

    for paramIndex, paramItem in enumerate(paramsOut):
        samplesOut[paramItem][startIndex:stopIndex] = samples[posIndex][paramItem][0:indexSize[posIndex]]

    startIndex = stopIndex

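#After this loop samplesOut holds the inputs back to back (file 0's samples,
#then file 1's, and so on); -m/--mix shuffles this ordering away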

#Mix samples
if args.mix:
    np.random.shuffle(samplesOut)


#Save output
if args.hdf:
    #HDF5 file with metadata
    with h5py.File(args.outfilename, "w") as outFile:
        group = outFile.create_group("lalinference")
        group = group.create_group(combineID)
        group = group.create_group(posterior_grp_name)
        for key in samplesOut.dtype.names:
            group.create_dataset(key, data=samplesOut[key], shuffle=True, compression="gzip")

        #Write the collected per-file attribute lists back as group attributes
        for level in metadata:
            for key in metadata[level]:
                outFile[level].attrs[key] = metadata[level][key]

else:
    #Standard text output
    paramHeader = "\t".join(paramsOut)
    np.savetxt(args.outfilename, samplesOut, delimiter="\t", header=paramHeader, comments="")


#Done!
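
#A minimal sketch of reading the combined samples back (assumes HDF5 output;
#"m1" is a hypothetical parameter name):
#    with h5py.File(args.outfilename, "r") as f:
#        post = f["lalinference/" + combineID + "/" + posterior_grp_name]
#        m1 = post["m1"][:]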