# Coverage report artifact: parallel_bilby/slurm/slurm.py — 93% line coverage.
import os
from argparse import Namespace
from os.path import abspath

import jinja2

from parallel_bilby.parser import create_analysis_parser

# Directory holding this module (and the slurm template shipped beside it).
DIR = os.path.dirname(__file__)
TEMPLATE_SLURM = "template_slurm.sh"
def load_template(template_file: str):
    """Return the jinja2 template *template_file* from this module's directory."""
    environment = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=DIR)
    )
    return environment.get_template(template_file)
def setup_submit(data_dump_file, inputs, args, cli_args):
    """Write the slurm scripts for every parallel analysis run (plus a merge
    script when there is more than one run) and a bash script that submits
    them with the right job dependencies.

    Returns the path of the generated bash submission script.
    """
    # One analysis node (and slurm script) per parallel run.
    analysis_nodes = []
    for run_index in range(args.n_parallel):
        analysis_node = AnalysisNode(data_dump_file, inputs, run_index, args, cli_args)
        analysis_node.write()
        analysis_nodes.append(analysis_node)

    needs_merge = len(analysis_nodes) > 1
    if needs_merge:
        final_analysis_node = MergeNodes(analysis_nodes, inputs, args)
        final_analysis_node.write()
    else:
        final_analysis_node = analysis_nodes[0]

    bash_script = f"{inputs.submit_directory}/bash_{inputs.label}.sh"
    with open(bash_script, "w+") as ff:
        # Submit each analysis job, capturing its job id so the merge job can
        # depend on all of them (``${jidN##* }`` strips "Submitted batch job ").
        dependent_job_ids = []
        for ii, node in enumerate(analysis_nodes):
            ff.write(f"jid{ii}=$(sbatch {node.filename})\n")
            dependent_job_ids.append(f"${{jid{ii}##* }}")
        if needs_merge:
            joined_ids = ":".join(dependent_job_ids)
            ff.write(
                f"sbatch --dependency=afterok:{joined_ids} "
                f"{final_analysis_node.filename}\n"
            )
        ff.write('squeue -u $USER -o "%u %.10j %.8A %.4C %.40E %R"\n')

    return bash_script
class BaseNode(object):
    """Shared behaviour for slurm job-script nodes.

    Subclasses are expected to define ``job_name``, ``logs`` and ``filename``
    and to override :meth:`get_contents` with a zero-argument version that
    calls back into this implementation with the command to run.
    """

    def __init__(self, inputs: Namespace, args: Namespace):
        self.inputs = inputs
        self.args = args

        # Slurm resource request, taken from the CLI arguments.
        self.nodes = args.nodes
        self.ntasks_per_node = args.ntasks_per_node
        self.time = args.time
        self.mem_per_cpu = args.mem_per_cpu

    def get_contents(self, command):
        """Render the slurm template around *command* and return the text."""
        log_file = f"{self.logs}/{self.job_name}_%j.log"

        # Each whitespace-separated token becomes its own #SBATCH directive.
        raw_extra = self.args.slurm_extra_lines
        if raw_extra is None:
            slurm_extra_lines = ""
        else:
            slurm_extra_lines = "\n".join(
                f"#SBATCH --{option}" for option in raw_extra.split()
            )

        if self.mem_per_cpu is not None:
            slurm_extra_lines += f"\n#SBATCH --mem-per-cpu={self.mem_per_cpu}"

        # Extra bash lines are supplied semicolon-separated on the CLI.
        if self.args.extra_lines:
            bash_extra_lines = "\n".join(
                segment.strip() for segment in self.args.extra_lines.split(";")
            )
        else:
            bash_extra_lines = ""

        template = load_template(TEMPLATE_SLURM)
        return template.render(
            job_name=self.job_name,
            nodes=self.nodes,
            ntasks_per_node=self.ntasks_per_node,
            time=self.time,
            log_file=log_file,
            mem_per_cpu=self.mem_per_cpu,
            slurm_extra_lines=slurm_extra_lines,
            bash_extra_lines=bash_extra_lines,
            command=command,
        )

    def write(self):
        """Render this node's job script and write it to ``self.filename``."""
        contents = self.get_contents()
        with open(self.filename, "w+") as script_file:
            print(contents, file=script_file)
class AnalysisNode(BaseNode):
    """Slurm node that runs one parallel analysis job via MPI."""

    def __init__(self, data_dump_file, inputs, idx, args, cli_args):
        super().__init__(inputs, args)
        self.data_dump_file = data_dump_file

        self.idx = idx
        self.filename = (
            f"{self.inputs.submit_directory}/"
            f"analysis_{self.inputs.label}_{self.idx}.sh"
        )
        self.job_name = f"{self.idx}_{self.inputs.label}"
        self.logs = self.inputs.data_analysis_log_directory

        # Re-parse the CLI with the analysis-stage parser so we know the
        # analysis defaults; only user overrides are forwarded explicitly
        # (see get_run_string).
        analysis_parser = create_analysis_parser(sampler=self.args.sampler)
        self.analysis_args, _ = analysis_parser.parse_known_args(args=cli_args)
        self.analysis_args.data_dump = self.data_dump_file

    @property
    def executable(self):
        """Name of the analysis executable for the configured sampler."""
        if self.args.sampler != "dynesty":
            raise ValueError(
                f"Unable to determine sampler to use from {self.args.sampler}"
            )
        return "parallel_bilby_analysis"

    @property
    def label(self):
        """Per-run label: the base label suffixed with this run's index."""
        return f"{self.inputs.label}_{self.idx}"

    @property
    def output_filename(self):
        """Path of the result file this analysis run will produce."""
        return (
            f"{self.inputs.result_directory}/"
            f"{self.inputs.label}_{self.idx}_result.{self.analysis_args.result_format}"
        )

    def get_contents(self):
        """Render the slurm script that launches the MPI analysis command."""
        return super().get_contents(
            command=f"mpirun {self.executable} {self.get_run_string()}"
        )

    def get_run_string(self):
        """Build the CLI string for the analysis executable.

        Starts from the data-dump path, forwards every argument whose value
        differs from the analysis-parser default, then appends the per-run
        label, seed and output directory.
        """
        skipped_keys = ("data_dump", "label", "outdir", "sampling_seed")
        pieces = [str(self.data_dump_file)]
        for key, default_value in vars(self.analysis_args).items():
            if key in skipped_keys:
                continue
            user_value = getattr(self.args, key)
            if default_value == user_value:
                continue
            flag = "--" + key.replace("_", "-")
            if user_value is True:
                # Boolean flags take no value.
                pieces.append(flag)
            elif isinstance(user_value, list):
                # Lists are forwarded as one flag per entry.
                pieces.extend(f"{flag} {entry}" for entry in user_value)
            else:
                pieces.append(f"{flag} {user_value}")
        pieces.append(f"--label {self.label}")
        pieces.append(f"--sampling-seed {self.inputs.sampling_seed + self.idx}")
        pieces.append(f"--outdir {abspath(self.inputs.result_directory)}")

        return " ".join(pieces)
class MergeNodes(BaseNode):
    """Slurm node that merges the per-run results with ``bilby_result``."""

    def __init__(self, analysis_nodes, inputs, args):
        super().__init__(inputs, args)
        self.analysis_nodes = analysis_nodes

        self.job_name = f"merge_{self.inputs.label}"
        # Merging is lightweight: one task with a fixed wall time and memory,
        # overriding whatever the analysis jobs requested.
        self.nodes = 1
        self.ntasks_per_node = 1
        self.time = "1:00:00"
        self.mem_per_cpu = "16GB"

        self.logs = self.inputs.data_analysis_log_directory
        self.filename = f"{self.inputs.submit_directory}/merge_{self.inputs.label}.sh"

    @property
    def file_list(self):
        """Space-separated result files of all analysis runs to merge."""
        return " ".join(node.output_filename for node in self.analysis_nodes)

    @property
    def merged_result_label(self):
        """Label for the merged result file."""
        return f"{self.inputs.label}_merged"

    def get_contents(self):
        """Render the slurm script for the bilby_result merge command."""
        command = (
            f"bilby_result -r {self.file_list}"
            f" --merge"
            f" --label {self.merged_result_label}"
            f" --outdir {self.inputs.result_directory}"
            f" -e {self.args.result_format}"
        )
        return super().get_contents(command=command)