"""
This is used for combining posterior samples together.
"""

import argparse
import sys

import numpy as np
import h5py
try:
    from lalinference import LALInferenceHDF5PosteriorSamplesDatasetName as posterior_grp_name
except ImportError:
    # Fall back to the standard group name if lalinference is not importable
    posterior_grp_name = "posterior_samples"
from collections import defaultdict
parser = argparse.ArgumentParser(description="Combine some posterior samples.")
parser.add_argument("-o", "--output", dest="outfilename", default="combined_posterior",
                    help="Write combined posterior to OUTFILE", metavar="OUTFILE")
parser.add_argument("-p", "--pos", action="append", dest="infilename",
                    help="Combine posteriors from INFILE", metavar="INFILE", required=True)
parser.add_argument("-a", "--all", dest="all",
                    action="store_true", default=False,
                    help="Use all posterior samples (do not weight)")
parser.add_argument("-w", "--weight", action="append", default=[], dest="weightings",
                    help="Weighting for posteriors", type=float)
shuffleGroup = parser.add_mutually_exclusive_group()
shuffleGroup.add_argument("-s", "--shuffle",
                          action="store_true", dest="shuffle", default=True,
                          help="Randomise posterior samples before combination [Default]")
shuffleGroup.add_argument("-ns", "--no-shuffle",
                          action="store_false", dest="shuffle",
                          help="Do not randomise posterior samples before combination")
parser.add_argument("-m", "--mix", dest="mix",
                    action="store_true", default=False,
                    help="Randomise combined samples")
fileGroup = parser.add_mutually_exclusive_group()
fileGroup.add_argument("-t", "--text",
                       action="store_false", dest="hdf", default=False,
                       help="Use ASCII posterior (.dat) files [Default]")
fileGroup.add_argument("-f", "--hdf",
                       action="store_true", dest="hdf",
                       help="Use HDF5 posterior files")
args = parser.parse_args()

nPos = np.size(args.infilename)
nWeight = np.size(args.weightings)
# Sanity-check the weighting options (the guard conditions are reconstructed;
# only the messages survive in the original excerpt)
if args.all and nWeight > 0:
    print("You cannot use all posterior samples and weight them!")
    sys.exit(1)
if nWeight > 0 and nWeight != nPos:
    print("Please either specify a weight for each posterior file or none")
    sys.exit(1)
if nWeight == 0:
    args.weightings = [1.0] * nPos
# Build an identifier for this combination from the options used.
# The initial stem is an assumed placeholder; its original value is not in this excerpt.
combineID = "combined"
if args.all:
    combineID = combineID + "_all"
else:
    combineID = combineID + "_weight_" + "_".join(map(str, args.weightings))
if args.shuffle:
    combineID = combineID + "_shuffle"
else:
    combineID = combineID + "_noshuffle"
if args.mix:
    combineID = combineID + "_mixed"

print("Combined ID:", combineID)
metadata = {
    "lalinference": defaultdict(lambda: [None] * nPos),
    "lalinference/" + combineID: defaultdict(lambda: [None] * nPos),
    "lalinference/" + combineID + "/" + posterior_grp_name: defaultdict(lambda: [None] * nPos)}
# Accumulators for the per-file samples, parameter names, and sample counts
samples = []
paramsList = []
sizeList = []

for posIndex in range(nPos):
    if args.hdf:
        # HDF5 input: read samples and collect metadata from each group level
        with h5py.File(args.infilename[posIndex], "r") as inFile:
            group = inFile["lalinference"]
            for key in group.attrs:
                metadata["lalinference"][key][posIndex] = group.attrs[key]

            # Each file is expected to contain a single run group
            run_id = list(group.keys())[0]
            group = group[run_id]
            for key in group.attrs:
                metadata["lalinference/" + combineID][key][posIndex] = group.attrs[key]

            # Record which run this file came from, going one level deeper each
            # time an already-combined file is combined again
            if "combined_run_ids" not in group.attrs:
                metadata["lalinference/" + combineID]["combined_run_ids"][posIndex] = run_id
            elif "recombined_run_ids" not in group.attrs:
                metadata["lalinference/" + combineID]["recombined_run_ids"][posIndex] = run_id
            elif "rerecombined_run_ids" not in group.attrs:
                metadata["lalinference/" + combineID]["rerecombined_run_ids"][posIndex] = run_id
            elif "rererecombined_run_ids" not in group.attrs:
                metadata["lalinference/" + combineID]["rererecombined_run_ids"][posIndex] = run_id
            else:
                print("Too many combinations to count!")

            group = group[posterior_grp_name]
            for key in group.attrs:
                metadata["lalinference/" + combineID + "/" + posterior_grp_name][key][posIndex] = group.attrs[key]

            # Build a structured array from the per-parameter datasets
            posDtype = []
            for key in group:
                posDtype.append((key, group[key].dtype))
                shape = group[key].shape
            posData = np.empty(shape, dtype=posDtype)
            for key in group:
                posData[key] = group[key][:]
    else:
        # ASCII input: one named column per parameter
        posData = np.genfromtxt(args.infilename[posIndex], names=True)

    if args.shuffle:
        np.random.shuffle(posData)

    samples.append(posData)
    paramsList.append(set(posData.dtype.names))
    sizeList.append(np.size(posData))
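# Only parameters present in every input file survive the combination; any
# column missing from one posterior is dropped from the combined output.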
paramsOut = list(set.intersection(*paramsList))
datatypes = samples[0][paramsOut].dtype

if args.all:
    # Keep every sample from every file
    sizeOut = sum(sizeList)
    samplesOut = np.empty(sizeOut, dtype=datatypes)
    indexSize = sizeList
else:
    # Draw from each file in proportion to its weight, limited by the file
    # that runs out of samples first
    fracWeight = np.asarray(args.weightings) / float(sum(args.weightings))
    testNum = fracWeight * float(sum(sizeList))
    minIndex = np.argmin(np.asarray(sizeList) / np.asarray(testNum))
    testSize = sizeList[minIndex] / fracWeight[minIndex]
    weightNum = np.around(fracWeight * testSize).astype(int)
    sizeOut = sum(weightNum)
    samplesOut = np.empty(sizeOut, dtype=datatypes)
    indexSize = weightNum

print("Using number of samples ", indexSize)
# Copy the requested number of samples from each input into the output array
startIndex = 0
for posIndex in range(0, nPos):
    stopIndex = startIndex + indexSize[posIndex]
    for paramIndex, paramItem in enumerate(paramsOut):
        samplesOut[paramItem][startIndex:stopIndex] = samples[posIndex][paramItem][0:indexSize[posIndex]]
    startIndex = stopIndex

if args.mix:
    np.random.shuffle(samplesOut)
if args.hdf:
    # HDF5 output mirroring the input layout; the combined samples are written
    # to the requested output name (the original path construction is assumed)
    path = args.outfilename
    with h5py.File(path, "w") as outFile:
        group = outFile.create_group("lalinference")
        group = group.create_group(combineID)
        group = group.create_group(posterior_grp_name)
        for key in samplesOut.dtype.names:
            group.create_dataset(key, data=samplesOut[key], shuffle=True, compression="gzip")
        # Attach the collected per-file metadata at each group level
        for level in metadata:
            for key in metadata[level]:
                outFile[level].attrs[key] = metadata[level][key]
else:
    # Plain-text output: tab-separated columns with a one-line header
    paramHeader = "\t".join(paramsOut)
    np.savetxt(args.outfilename, samplesOut.T, delimiter="\t", header=paramHeader, comments="")