Skip to content

Commit

Permalink
allow setting of memory and use new deterministic ml_collections-based script
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Jul 13, 2023
1 parent 53ec916 commit 8ebd2e3
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions bin/deterministic/queue-training
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,31 @@ app = typer.Typer()


def train_cmd(dataset, workdir, config_overrides=None):
    """Build the argument list that launches deterministic model training.

    Args:
        dataset: Dataset name, forwarded as the
            ``--config.data.dataset_name`` override.
        workdir: Working directory, passed to the script via ``--workdir``.
        config_overrides: Optional iterable of extra command-line arguments
            (e.g. ``--config.data.input_transform_key=spatial``) appended
            verbatim after the dataset override. Defaults to none.

    Returns:
        A new list of strings: interpreter, script path, the script's
        options, the dataset override, then any extra overrides.
    """
    # The default used to be the ``list`` type itself, which raised a
    # TypeError on concatenation whenever the argument was omitted.
    extra_args = [] if config_overrides is None else list(config_overrides)

    train_basecmd = ["python", "bin/deterministic/main.py"]

    # NOTE(review): only the main.py-style options are emitted here; the
    # positional workdir and ``--dataset`` flag that accompanied the earlier
    # train-model.py invocation are assumed superseded by ``--workdir`` and
    # the dataset_name override -- confirm against main.py's flags.
    train_opts = {
        "--config": "src/ml_downscaling_emulator/deterministic/configs/default.py",
        "--workdir": workdir,
        "--mode": "train",
    }

    return (
        train_basecmd
        # Flatten {flag: value} pairs into [flag, value, flag, value, ...].
        + [arg for item in train_opts.items() for arg in item]
        + [f"--config.data.dataset_name={dataset}"]
        + extra_args
    )


def queue_cmd(training_duration):
def queue_cmd(duration, memory):
queue_basecmd = ["lbatch"]

queue_opts = {
"-a": os.getenv("HPC_PROJECT_CODE"),
"-g": "1",
"-m": "16",
"-m": str(memory),
"-q": "cnu,gpu",
"-t": str(training_duration),
"-t": str(duration),
"--condaenv": "cuda-downscaling",
}

Expand All @@ -46,16 +48,20 @@ def queue_cmd(training_duration):
"ignore_unknown_options": True,
}
)
def main(ctx: typer.Context, model_run_id: str, cpm_dataset: str):
def main(
ctx: typer.Context,
model_run_id: str,
cpm_dataset: str,
memory: int = 64,
duration: int = 72,
):
# Add any other config on the commandline for training
# --config.data.input_transform_key=spatial

training_duration = 36

workdir = f"{os.getenv('DERIVED_DATA')}/workdirs/u-net/{model_run_id}"

full_cmd = (
queue_cmd(training_duration)
queue_cmd(duration=duration, memory=memory)
+ ["--"]
+ train_cmd(cpm_dataset, workdir, ctx.args)
)
Expand Down

0 comments on commit 8ebd2e3

Please sign in to comment.