"""This is used for combining posterior samples together."""

import argparse
import sys

import h5py
import numpy as np
# Prefer the dataset name exported by lalinference; fall back on the literal
# name if the module is not importable (the try/except guard is reconstructed).
try:
    from lalinference import LALInferenceHDF5PosteriorSamplesDatasetName as posterior_grp_name
except ImportError:
    posterior_grp_name = "posterior_samples"
from collections import defaultdict
parser = argparse.ArgumentParser(description="Combine some posterior samples.")
parser.add_argument("-o", "--output", dest="outfilename", default="combined_posterior",
                    help="Write combined posterior to OUTFILE", metavar="OUTFILE")
parser.add_argument("-p", "--pos", action="append", dest="infilename",
                    help="Combine posteriors from INFILE", metavar="INFILE", required=True)
parser.add_argument("-a", "--all", dest="all",
                    action="store_true", default=False,
                    help="Use all posterior samples (do not weight)")
parser.add_argument("-w", "--weight", action="append", default=[], dest="weightings",
                    help="Weighting for posteriors", type=float)
shuffleGroup = parser.add_mutually_exclusive_group()
shuffleGroup.add_argument("-s", "--shuffle",
                          action="store_true", dest="shuffle", default=True,
                          help="Randomise posterior samples before combination [Default]")
shuffleGroup.add_argument("-ns", "--no-shuffle",
                          action="store_false", dest="shuffle",
                          help="Do not randomise posterior samples before combination")
parser.add_argument("-m", "--mix", dest="mix",
                    action="store_true", default=False,
                    help="Randomise combined samples")
fileGroup = parser.add_mutually_exclusive_group()
fileGroup.add_argument("-t", "--text",
                       action="store_false", dest="hdf", default=False,
                       help="Use ASCII posterior (.dat) files [Default]")
fileGroup.add_argument("-f", "--hdf",
                       action="store_true", dest="hdf",
                       help="Use HDF5 posterior files")
args = parser.parse_args()
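# Example invocations (the script name combine_posteriors.py is hypothetical):
#   python combine_posteriors.py -p runA.dat -p runB.dat -o combined.dat
#   python combine_posteriors.py -f -p runA.hdf5 -p runB.hdf5 -w 2 -w 1 -o combined.hdf5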
# Sanity-check the number of weights against the number of posterior files
# (the guard conditions below are reconstructed from the error messages).
nPos = np.size(args.infilename)
nWeight = np.size(args.weightings)

print(args.weightings)

if args.all and nWeight > 0:
    print("You cannot use all posterior samples and weight them!")
    sys.exit(1)
if nWeight != 0 and nWeight != nPos:
    print("Please either specify a weight for each posterior file or none")
    sys.exit(1)
if nWeight == 0:
    args.weightings = [1.0] * nPos
# Build an identifier that records how the samples were combined.
combineID = "combined"
if args.all:
    combineID = combineID + "_all"
else:
    combineID = combineID + "_weight_" + '_'.join(map(str, args.weightings))
if args.shuffle:
    combineID = combineID + "_shuffle"
else:
    combineID = combineID + "_noshuffle"
if args.mix:
    combineID = combineID + "_mixed"

print("Combined ID:", combineID)
metadata = {"lalinference": defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID: defaultdict(lambda: [None]*nPos),
            "lalinference/"+combineID+"/"+posterior_grp_name: defaultdict(lambda: [None]*nPos)}
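# Each metadata level collects one entry per input file, so the combined HDF5
# output mirrors this layout (illustrative):
#   /lalinference                                -> top-level attrs per file
#   /lalinference/<combineID>                    -> run-level attrs + combined_run_ids
#   /lalinference/<combineID>/posterior_samples  -> posterior-group attrs
# Positions a file never sets remain None.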
samples = []
paramsList = []
sizeList = []

for posIndex in range(nPos):
    if args.hdf:
        with h5py.File(args.infilename[posIndex], "r") as inFile:
            group = inFile["lalinference"]
            for key in group.attrs:
                metadata["lalinference"][key][posIndex] = group.attrs[key]
            # Descend into the (single) run group under /lalinference.
            run_id = list(group.keys())[0]
            group = group[run_id]
            for key in group.attrs:
                metadata["lalinference/"+combineID][key][posIndex] = group.attrs[key]
            # Record the run IDs feeding this combination; each time an
            # already-combined file is combined again, go one "re" deeper.
            if "combined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["combined_run_ids"][posIndex] = run_id
            elif "recombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["recombined_run_ids"][posIndex] = run_id
            elif "rerecombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["rerecombined_run_ids"][posIndex] = run_id
            elif "rererecombined_run_ids" not in group.attrs:
                metadata["lalinference/"+combineID]["rererecombined_run_ids"][posIndex] = run_id
            else:
                print("Too many combinations to count!")
            # Read the posterior datasets into one structured array.
            group = group[posterior_grp_name]
            for key in group.attrs:
                metadata["lalinference/"+combineID+"/"+posterior_grp_name][key][posIndex] = group.attrs[key]
            posDtype = []
            for key in group:
                posDtype.append((key, group[key].dtype))
                shape = group[key].shape
            posData = np.empty(shape, dtype=posDtype)
            for key in group:
                posData[key] = group[key][:]
    else:
        posData = np.genfromtxt(args.infilename[posIndex], names=True)

    if args.shuffle:
        np.random.shuffle(posData)

    samples.append(posData)
    paramsList.append(set(posData.dtype.names))
    sizeList.append(np.size(posData))
# Only parameters present in every input file can be combined.
paramsOut = list(set.intersection(*paramsList))

datatypes = samples[0][paramsOut].dtype

if args.all:
    # Keep every sample from every file.
    sizeOut = sum(sizeList)
    samplesOut = np.empty(sizeOut, dtype=datatypes)
    indexSize = sizeList
else:
    # Normalise the weights, find which file runs out of samples first, and
    # scale the total so that no file is asked for more than it contains.
    fracWeight = np.asarray(args.weightings) / float(sum(args.weightings))
    testNum = fracWeight * float(sum(sizeList))
    minIndex = np.argmin(np.asarray(sizeList) / np.asarray(testNum))
    testSize = sizeList[minIndex] / fracWeight[minIndex]
    weightNum = np.around(fracWeight * testSize).astype(int)
    sizeOut = sum(weightNum)
    samplesOut = np.empty(sizeOut, dtype=datatypes)
    indexSize = weightNum

print("Using number of samples ", indexSize)
# Copy the requested number of samples from each file into the output array.
startIndex = 0
for posIndex in range(nPos):
    stopIndex = startIndex + indexSize[posIndex]
    for paramItem in paramsOut:
        samplesOut[paramItem][startIndex:stopIndex] = samples[posIndex][paramItem][0:indexSize[posIndex]]
    startIndex = stopIndex

if args.mix:
    np.random.shuffle(samplesOut)
if args.hdf:
    # The fragment does not show how the output path is derived; assume the
    # output file name is used directly.
    path = args.outfilename
    with h5py.File(path, "w") as outFile:
        group = outFile.create_group("lalinference")
        group = group.create_group(combineID)
        group = group.create_group(posterior_grp_name)
        for key in samplesOut.dtype.names:
            group.create_dataset(key, data=samplesOut[key], shuffle=True, compression="gzip")
        # Attach the collected metadata to the matching groups.
        for level in metadata:
            for key in metadata[level]:
                outFile[level].attrs[key] = metadata[level][key]
else:
    paramHeader = "\t".join(paramsOut)
    np.savetxt(args.outfilename, samplesOut.T, delimiter="\t", header=paramHeader, comments="")