From 7e238e45a5376f4966b994272d93d0747add0c28 Mon Sep 17 00:00:00 2001
From: icedoom888
Date: Wed, 9 Oct 2024 15:25:03 +0200
Subject: [PATCH] Introduced resume flag and checkpoint loading for transfer
 learning, removed metadata saving in checkpoints due to corruption error on
 big models, fixed logging to work in the transfer learning setting

---
 .../diagnostics/callbacks/__init__.py     |  4 +--
 .../training/diagnostics/mlflow/logger.py |  4 ++-
 src/anemoi/training/train/train.py        | 25 ++++++++++++++-----
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py
index f2195b5f..f47db036 100644
--- a/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -997,7 +997,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s
 
             torch.save(model, inference_checkpoint_filepath)
 
-            save_metadata(inference_checkpoint_filepath, metadata)
+            # save_metadata(inference_checkpoint_filepath, metadata)
 
             model.config = save_config
             model.metadata = tmp_metadata
@@ -1016,7 +1016,7 @@ def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: s
             from weakref import proxy
 
             # save metadata for the training checkpoint in the same format as inference
-            save_metadata(lightning_checkpoint_filepath, metadata)
+            # save_metadata(lightning_checkpoint_filepath, metadata)
 
         # notify loggers
         for logger in trainer.loggers:
diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py
index 7854c172..90978e76 100644
--- a/src/anemoi/training/diagnostics/mlflow/logger.py
+++ b/src/anemoi/training/diagnostics/mlflow/logger.py
@@ -70,7 +70,7 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N
     if len(sys.argv) > 1:
         # add the arguments to the command tag
         tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:])
-    if config.training.run_id or config.training.fork_run_id:
+    if (config.training.run_id or config.training.fork_run_id) and config.training.resume:
         "Either run_id or fork_run_id must be provided to resume a run."
 
         import mlflow
@@ -85,11 +85,13 @@ def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | N
             run_name = mlflow_client.get_run(parent_run_id).info.run_name
             tags["mlflow.parentRunId"] = parent_run_id
             tags["resumedRun"] = "True"  # tags can't take boolean values
+
         elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child:
             run_id = config.training.run_id
             run_name = mlflow_client.get_run(run_id).info.run_name
             mlflow_client.update_run(run_id=run_id, status="RUNNING")
             tags["resumedRun"] = "True"
+
         else:
             parent_run_id = config.training.fork_run_id
             tags["forkedRun"] = "True"
diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py
index f48b9467..00209a94 100644
--- a/src/anemoi/training/train/train.py
+++ b/src/anemoi/training/train/train.py
@@ -62,7 +62,7 @@ def __init__(self, config: DictConfig) -> None:
         self.config = config
 
         # Default to not warm-starting from a checkpoint
-        self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)
+        self.start_from_checkpoint = (bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)) and self.config.training.resume
         self.load_weights_only = config.training.load_weights_only
         self.parent_uuid = None
 
@@ -141,7 +141,8 @@ def model(self) -> GraphForecaster:
         }
         if self.load_weights_only:
             LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
-            return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
+            return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+
         return GraphForecaster(**kwargs)
 
     @rank_zero_only
@@ -187,12 +188,19 @@ def last_checkpoint(self) -> str | None:
         """Path to the last checkpoint."""
         if not self.start_from_checkpoint:
             return None
-
+
         checkpoint = Path(
             self.config.hardware.paths.checkpoints.parent,
-            self.config.training.fork_run_id or self.run_id,
-            self.config.hardware.files.warm_start or "last.ckpt",
+            self.config.training.fork_run_id,  # or self.run_id,
+            self.config.hardware.files.warm_start or "transfer.ckpt",
         )
+        # Transfer learning or continue training
+        if not Path(checkpoint).exists():
+            checkpoint = Path(
+                self.config.hardware.paths.checkpoints.parent,
+                self.config.training.fork_run_id,  # or self.run_id,
+                self.config.hardware.files.warm_start or "last.ckpt",
+            )
 
         # Check if the last checkpoint exists
         if Path(checkpoint).exists():
@@ -313,6 +321,9 @@ def strategy(self) -> DDPGroupStrategy:
 
     def train(self) -> None:
         """Training entry point."""
+
+        print('Setting up trainer..')
+
         trainer = pl.Trainer(
             accelerator=self.accelerator,
             callbacks=self.callbacks,
@@ -328,7 +339,7 @@ def train(self) -> None:
             # run a fixed no of batches per epoch (helpful when debugging)
             limit_train_batches=self.config.dataloader.limit_batches.training,
             limit_val_batches=self.config.dataloader.limit_batches.validation,
-            num_sanity_val_steps=4,
+            num_sanity_val_steps=0,
             accumulate_grad_batches=self.config.training.accum_grad_batches,
             gradient_clip_val=self.config.training.gradient_clip.val,
             gradient_clip_algorithm=self.config.training.gradient_clip.algorithm,
@@ -338,6 +349,8 @@ def train(self) -> None:
             enable_progress_bar=self.config.diagnostics.enable_progress_bar,
         )
 
+        print('Starting training..')
+
         trainer.fit(
             self.model,
             datamodule=self.datamodule,
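
For reference, a minimal sketch (not part of the patch) of how the new `config.training.resume` flag gates warm-starting, assuming an OmegaConf config with the same keys the diff reads; the fork run id value below is illustrative only.

from omegaconf import OmegaConf

# Hypothetical config fragment; only the keys referenced by the patch are shown.
config = OmegaConf.create(
    {
        "training": {
            "run_id": None,           # id of a run to resume in place (None here)
            "fork_run_id": "abc123",  # illustrative id of a run to warm-start / transfer-learn from
            "resume": True,           # new flag introduced by this patch
        }
    }
)

# Mirrors the updated expression in AnemoiTrainer.__init__: a checkpoint is only
# loaded when a run id is given *and* resume is enabled.
start_from_checkpoint = (
    bool(config.training.run_id) or bool(config.training.fork_run_id)
) and config.training.resume
print(start_from_checkpoint)  # True for this example config

With `resume: False`, the same expression evaluates to False, so no checkpoint path is resolved even if `fork_run_id` is set, and the MLflow logger no longer tags the run as resumed or forked.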