Skip to content

Commit

Permalink
bugfixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
guipenedo committed Jul 10, 2024
1 parent 6142073 commit bc19878
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/datatrove/executor/slurm_nodes.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,11 +161,11 @@ def run(self):
"""
if "SLURM_ARRAY_TASK_ID" in os.environ:
# we are already "inside" the slurm task, get our rank from env vars and run pipeline
# logger.warning(f'{os.environ["SLURM_ARRAY_TASK_ID"]=}, {self.max_array_size=}, {os.environ.get("RUN_OFFSET", 0)=}, ')

slurm_rank = (
int(os.environ["SLURM_ARRAY_TASK_ID"])
+ self.max_array_size * int(os.environ.get("RUN_OFFSET", 0)) * self.tasks_per_node
+ int(os.environ.get("SLURM_PROCID"))
)
int(os.environ["SLURM_ARRAY_TASK_ID"]) + self.max_array_size * int(os.environ.get("RUN_OFFSET", 0))
) * self.tasks_per_node + int(os.environ.get("SLURM_PROCID"))

ranks_to_run_range = (slurm_rank * self.tasks_per_job, (slurm_rank + 1) * self.tasks_per_job)
with self.logging_dir.open("ranks_to_run.json", "r") as ranks_to_run_file:
Expand Down Expand Up @@ -262,7 +262,7 @@ def launch_job(self):
srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else ""
launch_file_contents = self.get_launch_file_contents(
self.get_sbatch_args(max_array),
f"srun {srun_args_str} --environment=datatrove --ntasks={self.tasks_per_node} --cpus-per-task={self.cpus_per_task} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
f"srun {srun_args_str} --environment=datatrove -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
)
# save it
with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f:
Expand Down Expand Up @@ -298,8 +298,8 @@ def get_sbatch_args(self, max_array: int = 1) -> dict:
os.makedirs(self.slurm_logs_folder, exist_ok=True)
slurm_logfile = os.path.join(self.slurm_logs_folder, "%A_%a.out")
sbatch_args = {
# "cpus-per-task": self.cpus_per_node,
"ntasks-per-node": 1,
"cpus-per-task": self.cpus_per_task,
"ntasks-per-node": self.tasks_per_node,
"nodes": 1,
# "mem-per-cpu": f"{self.mem_per_cpu_gb}G",
"partition": self.partition,
Expand Down

0 comments on commit bc19878

Please sign in to comment.