From cd811115286db8007be17075d92c9e7e1809baad Mon Sep 17 00:00:00 2001
From: tj-solergibert
Date: Mon, 16 Sep 2024 17:26:21 +0200
Subject: [PATCH] Only load model parameters on SFT

---
 src/nanotron/config/config.py | 1 +
 src/nanotron/trainer.py       | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index 4f8cc1c2..4e72b330 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -322,6 +322,7 @@ class OptimizerArgs:
     clip_grad: Optional[float]
     accumulate_grad_in_fp32: bool
     learning_rate_scheduler: LRSchedulerArgs
+    sft: bool = False


 @dataclass
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 3994ddd3..dc074510 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -191,7 +191,7 @@ def __init__(
             optimizer_args=self.config.optimizer,
             parallel_context=self.parallel_context,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_optimizer(
                 optimizer=self.optimizer,
                 parallel_context=self.parallel_context,
@@ -206,7 +206,7 @@ def __init__(
             lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
             total_training_steps=self.config.tokens.train_steps,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
                 parallel_context=self.parallel_context,
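
Usage note (not part of the patch): per the commit subject, the new optimizer.sft flag is intended for starting supervised fine-tuning from a pretrained checkpoint, so the model weights are still restored while the optimizer and LR-scheduler state are skipped and start fresh. The sketch below is a minimal, self-contained Python illustration of that resume control flow; the dataclass and function names are placeholders for illustration, not nanotron's actual API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class OptimizerArgsSketch:
    # Stand-in for nanotron's OptimizerArgs, reduced to the field this patch adds.
    sft: bool = False


def resume_from_checkpoint(init_checkpoint_path: Optional[str], optimizer_args: OptimizerArgsSketch) -> None:
    # Mirrors the guards introduced in trainer.py: weights are always restored
    # when a checkpoint path is given, optimizer and LR-scheduler state only
    # when sft is left at its default of False.
    if init_checkpoint_path is None:
        return
    print(f"loading model weights from {init_checkpoint_path}")
    if not optimizer_args.sft:
        print("loading optimizer state")
        print("loading LR scheduler state")


# SFT run: reuse pretrained weights, start optimizer and scheduler from scratch.
resume_from_checkpoint("checkpoints/pretrained", OptimizerArgsSketch(sft=True))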