Commit 7e238e4

Introduced resume flag and checkpoint loading for transfer learning, removed metadata saving in checkpoints due to a corruption error on big models, and fixed logging to work in the transfer learning setting.
icedoom888 committed Oct 9, 2024
1 parent a599dfd commit 7e238e4
Showing 3 changed files with 24 additions and 9 deletions.
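The gate described in the commit message is driven by fields of the training config. A minimal sketch of the fields this commit touches (all values below are illustrative placeholders; only the field names appear in the diffs that follow):

from omegaconf import OmegaConf

# Hypothetical excerpt of the training config after this change.
config = OmegaConf.create(
    {
        "training": {
            "run_id": None,             # resume this run in place, or ...
            "fork_run_id": "abc123",    # ... fork from another run's checkpoints (placeholder id)
            "resume": True,             # new gate: without it, run_id/fork_run_id no longer trigger checkpoint loading
            "load_weights_only": True,  # restore model weights only, not trainer state
        },
        "hardware": {"files": {"warm_start": None}},  # optional explicit checkpoint file name
    }
)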
4 changes: 2 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -997,7 +997,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s

torch.save(model, inference_checkpoint_filepath)

- save_metadata(inference_checkpoint_filepath, metadata)
+ # save_metadata(inference_checkpoint_filepath, metadata)

model.config = save_config
model.metadata = tmp_metadata
@@ -1016,7 +1016,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s
from weakref import proxy

# save metadata for the training checkpoint in the same format as inference
- save_metadata(lightning_checkpoint_filepath, metadata)
+ # save_metadata(lightning_checkpoint_filepath, metadata)

# notify loggers
for logger in trainer.loggers:
4 changes: 3 additions & 1 deletion src/anemoi/training/diagnostics/mlflow/logger.py
@@ -70,7 +70,7 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N
if len(sys.argv) > 1:
# add the arguments to the command tag
tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:])
- if config.training.run_id or config.training.fork_run_id:
+ if (config.training.run_id or config.training.fork_run_id) and config.training.resume:
"Either run_id or fork_run_id must be provided to resume a run."

import mlflow
@@ -85,11 +85,13 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N
run_name = mlflow_client.get_run(parent_run_id).info.run_name
tags["mlflow.parentRunId"] = parent_run_id
tags["resumedRun"] = "True" # tags can't take boolean values

elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child:
run_id = config.training.run_id
run_name = mlflow_client.get_run(run_id).info.run_name
mlflow_client.update_run(run_id=run_id, status="RUNNING")
tags["resumedRun"] = "True"

else:
parent_run_id = config.training.fork_run_id
tags["forkedRun"] = "True"
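The same resume gate reappears in train.py below; as a standalone sketch (the helper name is illustrative, the field names come from the diff), the new behaviour is:

# A run is only treated as resumed or forked when training.resume is set, so
# supplying run_id or fork_run_id alone no longer selects the MLflow
# resume/fork branches above, nor checkpoint loading in the trainer.
def is_resume_requested(training_cfg) -> bool:
    return bool(training_cfg.run_id or training_cfg.fork_run_id) and bool(training_cfg.resume)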
25 changes: 19 additions & 6 deletions src/anemoi/training/train/train.py
@@ -62,7 +62,7 @@ def __init__(self, config: DictConfig) -> None:
self.config = config

# Default to not warm-starting from a checkpoint
- self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)
+ self.start_from_checkpoint = (bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)) and self.config.training.resume
self.load_weights_only = config.training.load_weights_only
self.parent_uuid = None

@@ -141,7 +141,8 @@ def model(self) -> GraphForecaster:
}
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
- return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
+ return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)

return GraphForecaster(**kwargs)

@rank_zero_only
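Passing strict=False through to load_from_checkpoint relaxes key matching in PyTorch's load_state_dict, which is what allows a transfer-learning target to differ from the checkpointed model. A self-contained toy example (not the project's model):

from torch import nn

# The target has an extra head that the checkpoint knows nothing about.
# With strict=True this raises a RuntimeError; with strict=False the extra
# parameters simply keep their fresh initialisation.
src = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
dst = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4), nn.Linear(4, 2))

result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)     # ['3.weight', '3.bias']
print(result.unexpected_keys)  # []
# Note: shape mismatches on keys present in both models still raise,
# regardless of strict.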
@@ -187,12 +188,19 @@ def last_checkpoint(self) -> str | None:
"""Path to the last checkpoint."""
if not self.start_from_checkpoint:
return None

checkpoint = Path(
self.config.hardware.paths.checkpoints.parent,
- self.config.training.fork_run_id or self.run_id,
- self.config.hardware.files.warm_start or "last.ckpt",
+ self.config.training.fork_run_id, # or self.run_id,
+ self.config.hardware.files.warm_start or "transfer.ckpt",
)
+ # Transfer learning or continue training
+ if not Path(checkpoint).exists():
+ checkpoint = Path(
+ self.config.hardware.paths.checkpoints.parent,
+ self.config.training.fork_run_id, # or self.run_id,
+ self.config.hardware.files.warm_start or "last.ckpt",
+ )

# Check if the last checkpoint exists
if Path(checkpoint).exists():
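The lookup order above, written out as a standalone sketch (the function name is illustrative; the directory layout and file names are taken from the diff):

from pathlib import Path

# Prefer a transfer-learning checkpoint in the forked run's checkpoint
# directory, fall back to the regular last.ckpt to continue training.
# If hardware.files.warm_start is set explicitly, both candidates resolve
# to the same file, so the fallback is a no-op in that case.
def resolve_warm_start_checkpoint(checkpoints_parent: Path, fork_run_id: str, warm_start: str | None) -> Path:
    checkpoint = checkpoints_parent / fork_run_id / (warm_start or "transfer.ckpt")
    if not checkpoint.exists():
        checkpoint = checkpoints_parent / fork_run_id / (warm_start or "last.ckpt")
    return checkpoint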
@@ -313,6 +321,9 @@ def strategy(self) -> DDPGroupStrategy:

def train(self) -> None:
"""Training entry point."""

+ print('Setting up trainer..')

trainer = pl.Trainer(
accelerator=self.accelerator,
callbacks=self.callbacks,
@@ -328,7 +339,7 @@ def train(self) -> None:
# run a fixed no of batches per epoch (helpful when debugging)
limit_train_batches=self.config.dataloader.limit_batches.training,
limit_val_batches=self.config.dataloader.limit_batches.validation,
- num_sanity_val_steps=4,
+ num_sanity_val_steps=0,
accumulate_grad_batches=self.config.training.accum_grad_batches,
gradient_clip_val=self.config.training.gradient_clip.val,
gradient_clip_algorithm=self.config.training.gradient_clip.algorithm,
@@ -338,6 +349,8 @@ def train(self) -> None:
enable_progress_bar=self.config.diagnostics.enable_progress_bar,
)

+ print('Starting training..')

trainer.fit(
self.model,
datamodule=self.datamodule,
