Coverage for parallel_bilby/slurm/slurm.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

127 statements  

1import os 

2from argparse import Namespace 

3from os.path import abspath 

4 

5import jinja2 

6from parallel_bilby.parser import create_analysis_parser 

7 

8DIR = os.path.dirname(__file__) 

9TEMPLATE_SLURM = "template_slurm.sh" 

10 

11 

def load_template(template_file: str):
    """Load and return the jinja2 template named *template_file*.

    Templates are looked up in the directory containing this module.
    """
    environment = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=DIR)
    )
    return environment.get_template(template_file)

17 

18 

def setup_submit(data_dump_file, inputs, args, cli_args):
    """Write the slurm submit scripts and a bash driver that submits them.

    One analysis script is written per parallel chunk; when more than one
    chunk is requested, a merge script is also written and submitted with an
    ``afterok`` dependency on every analysis job.

    Returns the path of the generated bash driver script.
    """
    # One analysis job per parallel chunk.
    analysis_nodes = []
    for idx in range(args.n_parallel):
        node = AnalysisNode(data_dump_file, inputs, idx, args, cli_args)
        node.write()
        analysis_nodes.append(node)

    # A single chunk needs no merge step; otherwise a merge job finishes the run.
    if len(analysis_nodes) == 1:
        final_analysis_node = analysis_nodes[0]
    else:
        final_analysis_node = MergeNodes(analysis_nodes, inputs, args)
        final_analysis_node.write()

    # Assemble the driver script: submit each analysis job, capture its job
    # id, then (if merging) submit the merge job depending on all of them.
    submit_lines = []
    dependent_job_ids = []
    for ii, node in enumerate(analysis_nodes):
        submit_lines.append(f"jid{ii}=$(sbatch {node.filename})")
        # "${jidN##* }" strips everything up to the last space, leaving the id.
        dependent_job_ids.append(f"${{jid{ii}##* }}")
    if len(analysis_nodes) > 1:
        submit_lines.append(
            f"sbatch --dependency=afterok:{':'.join(dependent_job_ids)} "
            f"{final_analysis_node.filename}"
        )
    submit_lines.append('squeue -u $USER -o "%u %.10j %.8A %.4C %.40E %R"')

    bash_script = f"{inputs.submit_directory}/bash_{inputs.label}.sh"
    with open(bash_script, "w+") as ff:
        print("\n".join(submit_lines), file=ff)

    return bash_script

48 

49 

class BaseNode(object):
    """Shared machinery for rendering a slurm submission script.

    Subclasses must set ``job_name``, ``logs`` and ``filename`` attributes
    and provide a ``get_contents()`` that supplies the command to run.
    """

    def __init__(self, inputs: Namespace, args: Namespace):
        self.inputs = inputs
        self.args = args
        # Slurm resource requests, taken straight from the CLI arguments.
        self.nodes = args.nodes
        self.ntasks_per_node = args.ntasks_per_node
        self.time = args.time
        self.mem_per_cpu = args.mem_per_cpu

    def get_contents(self, command):
        """Render the slurm template for this node around *command*."""
        template = load_template(TEMPLATE_SLURM)
        log_file = f"{self.logs}/{self.job_name}_%j.log"

        # Extra "#SBATCH" directives requested on the command line,
        # supplied as whitespace-separated option strings.
        extra = self.args.slurm_extra_lines
        if extra is None:
            slurm_extra_lines = ""
        else:
            slurm_extra_lines = "\n".join(
                f"#SBATCH --{option}" for option in extra.split()
            )

        if self.mem_per_cpu is not None:
            slurm_extra_lines += f"\n#SBATCH --mem-per-cpu={self.mem_per_cpu}"

        # Arbitrary bash lines (";"-separated) inserted before the command.
        if self.args.extra_lines:
            bash_extra_lines = "\n".join(
                part.strip() for part in self.args.extra_lines.split(";")
            )
        else:
            bash_extra_lines = ""

        return template.render(
            job_name=self.job_name,
            nodes=self.nodes,
            ntasks_per_node=self.ntasks_per_node,
            time=self.time,
            log_file=log_file,
            mem_per_cpu=self.mem_per_cpu,
            slurm_extra_lines=slurm_extra_lines,
            bash_extra_lines=bash_extra_lines,
            command=command,
        )

    def write(self):
        """Render this node's script and write it to ``self.filename``."""
        # Render first so a template failure does not leave a truncated file.
        content = self.get_contents()
        with open(self.filename, "w+") as f:
            print(content, file=f)

98 

99 

class AnalysisNode(BaseNode):
    """Slurm job running ``parallel_bilby_analysis`` on one parallel chunk."""

    def __init__(self, data_dump_file, inputs, idx, args, cli_args):
        super().__init__(inputs, args)
        self.data_dump_file = data_dump_file
        self.idx = idx
        self.job_name = f"{self.idx}_{self.inputs.label}"
        self.logs = self.inputs.data_analysis_log_directory
        self.filename = (
            f"{self.inputs.submit_directory}/"
            f"analysis_{self.inputs.label}_{self.idx}.sh"
        )

        # Re-parse the CLI arguments with the analysis parser so that
        # get_run_string() can compare them against the analysis defaults.
        parser = create_analysis_parser(sampler=self.args.sampler)
        self.analysis_args, _ = parser.parse_known_args(args=cli_args)
        self.analysis_args.data_dump = self.data_dump_file

    @property
    def executable(self):
        """Name of the analysis executable for the configured sampler."""
        if self.args.sampler != "dynesty":
            raise ValueError(
                f"Unable to determine sampler to use from {self.args.sampler}"
            )
        return "parallel_bilby_analysis"

    @property
    def label(self):
        """Per-chunk result label: ``<label>_<idx>``."""
        return f"{self.inputs.label}_{self.idx}"

    @property
    def output_filename(self):
        """Path of the result file this analysis job will produce."""
        return (
            f"{self.inputs.result_directory}/"
            f"{self.inputs.label}_{self.idx}_result.{self.analysis_args.result_format}"
        )

    def get_contents(self):
        """Render the full submit script for this analysis job."""
        command = f"mpirun {self.executable} {self.get_run_string()}"
        return super().get_contents(command=command)

    def get_run_string(self):
        """Build the command-line string for the analysis executable.

        Only options whose value differs from the analysis-parser value are
        emitted; label, sampling seed and outdir are always set explicitly.
        """
        run_list = [f"{self.data_dump_file}"]
        handled_separately = ("data_dump", "label", "outdir", "sampling_seed")
        for key, val in vars(self.analysis_args).items():
            if key in handled_separately:
                continue
            input_val = getattr(self.args, key)
            if val == input_val:
                continue
            key = key.replace("_", "-")
            if input_val is True:
                # Boolean flags take no value.
                run_list.append(f"--{key}")
            elif isinstance(input_val, list):
                # Repeat the option once per list entry.
                run_list.extend(f"--{key} {entry}" for entry in input_val)
            else:
                run_list.append(f"--{key} {input_val}")

        run_list.append(f"--label {self.label}")
        run_list.append(f"--sampling-seed {self.inputs.sampling_seed + self.idx}")
        run_list.append(f"--outdir {abspath(self.inputs.result_directory)}")

        return " ".join(run_list)

164 

165 

class MergeNodes(BaseNode):
    """Slurm job that merges the per-chunk results with ``bilby_result``
    once every analysis job has finished successfully."""

    def __init__(self, analysis_nodes, inputs, args):
        super().__init__(inputs, args)
        self.analysis_nodes = analysis_nodes
        self.job_name = f"merge_{self.inputs.label}"
        # Merging is lightweight: override the CLI resource requests with a
        # single short single-task job.
        self.nodes = 1
        self.ntasks_per_node = 1
        self.time = "1:00:00"
        self.mem_per_cpu = "16GB"
        self.logs = self.inputs.data_analysis_log_directory
        self.filename = f"{self.inputs.submit_directory}/merge_{self.inputs.label}.sh"

    @property
    def file_list(self):
        """Space-separated result files of all analysis jobs."""
        return " ".join(node.output_filename for node in self.analysis_nodes)

    @property
    def merged_result_label(self):
        """Label for the combined result: ``<label>_merged``."""
        return f"{self.inputs.label}_merged"

    def get_contents(self):
        """Render the submit script that runs the merge command."""
        command = " ".join(
            [
                f"bilby_result -r {self.file_list}",
                "--merge",
                f"--label {self.merged_result_label}",
                f"--outdir {self.inputs.result_directory}",
                f"-e {self.args.result_format}",
            ]
        )
        return super().get_contents(command=command)