diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index 4f8cc1c2..4e72b330 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -322,6 +322,7 @@ class OptimizerArgs:
     clip_grad: Optional[float]
     accumulate_grad_in_fp32: bool
     learning_rate_scheduler: LRSchedulerArgs
+    sft: bool = False


 @dataclass
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 3994ddd3..dc074510 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -191,7 +191,7 @@ def __init__(
             optimizer_args=self.config.optimizer,
             parallel_context=self.parallel_context,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_optimizer(
                 optimizer=self.optimizer,
                 parallel_context=self.parallel_context,
@@ -206,7 +206,7 @@ def __init__(
             lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
             total_training_steps=self.config.tokens.train_steps,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
                 parallel_context=self.parallel_context,
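
For readers skimming the patch, here is a minimal, self-contained Python sketch of the behavior it introduces. The OptimizerArgsSketch dataclass and should_restore_training_state helper are hypothetical stand-ins, not nanotron code: only clip_grad, accumulate_grad_in_fp32, and sft are taken from the diff above. The real logic lives in the trainer's __init__, where the two new guards skip load_optimizer and load_lr_scheduler whenever sft is enabled, so an SFT run reuses the checkpoint's model weights but starts with a freshly initialized optimizer and LR scheduler.

# Hedged sketch only, not part of the patch. Field names other than `sft`
# come from the diff above; the helper and its arguments are hypothetical.
from dataclasses import dataclass
from typing import Optional


@dataclass
class OptimizerArgsSketch:
    clip_grad: Optional[float]
    accumulate_grad_in_fp32: bool
    sft: bool = False  # new flag: fine-tune from checkpoint weights with fresh optimizer state


def should_restore_training_state(init_checkpoint_path: Optional[str], args: OptimizerArgsSketch) -> bool:
    # Mirrors the two guards added in trainer.py: optimizer and LR-scheduler
    # state are restored only when resuming a run, never when `sft` is set.
    return init_checkpoint_path is not None and not args.sft


# Resuming pre-training from a checkpoint: restore optimizer/scheduler state.
assert should_restore_training_state("checkpoints/10000", OptimizerArgsSketch(1.0, True, sft=False))
# Starting SFT from the same checkpoint: keep weights, reinitialize the rest.
assert not should_restore_training_state("checkpoints/10000", OptimizerArgsSketch(1.0, True, sft=True))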