Only load model parameters on SFT
TJ-Solergibert committed Sep 16, 2024
1 parent ed51183 commit cd81111
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/nanotron/config/config.py

@@ -322,6 +322,7 @@ class OptimizerArgs:
     clip_grad: Optional[float]
     accumulate_grad_in_fp32: bool
     learning_rate_scheduler: LRSchedulerArgs
+    sft: bool = False


 @dataclass
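The flag defaults to False, so existing configs that never mention it keep the old resume behaviour. A minimal standalone sketch of how the new field behaves (a trimmed stand-in for OptimizerArgs, not the full class from the repo, which carries more fields):

from dataclasses import dataclass
from typing import Optional

# Trimmed stand-in for nanotron's OptimizerArgs, only to illustrate the new flag.
@dataclass
class OptimizerArgsSketch:
    clip_grad: Optional[float]
    accumulate_grad_in_fp32: bool
    sft: bool = False  # True for SFT runs that should reuse only the model weights

pretraining = OptimizerArgsSketch(clip_grad=1.0, accumulate_grad_in_fp32=True)
sft_run = OptimizerArgsSketch(clip_grad=1.0, accumulate_grad_in_fp32=True, sft=True)
assert pretraining.sft is False  # default: full training state restored on resume
assert sft_run.sft is True       # SFT: optimizer/LR-scheduler state starts fresh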
4 changes: 2 additions & 2 deletions src/nanotron/trainer.py

@@ -191,7 +191,7 @@ def __init__(
             optimizer_args=self.config.optimizer,
             parallel_context=self.parallel_context,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_optimizer(
                 optimizer=self.optimizer,
                 parallel_context=self.parallel_context,

@@ -206,7 +206,7 @@ def __init__(
             lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
             total_training_steps=self.config.tokens.train_steps,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and not self.config.optimizer.sft:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
                 parallel_context=self.parallel_context,
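Both guards apply the same condition: restore optimizer and LR-scheduler state only when resuming from a checkpoint and the run is not SFT; otherwise only the model parameters are loaded. A minimal standalone sketch of that predicate (simplified names, not the actual trainer code):

from typing import Optional

def should_restore_training_state(init_checkpoint_path: Optional[str], sft: bool) -> bool:
    # Mirrors the guard added in trainer.py: optimizer/LR-scheduler state is
    # restored only when a checkpoint is given and the run is not SFT.
    return init_checkpoint_path is not None and not sft

assert should_restore_training_state("checkpoints/step-1000", sft=False) is True
assert should_restore_training_state("checkpoints/step-1000", sft=True) is False  # SFT: weights only
assert should_restore_training_state(None, sft=False) is False                    # fresh run, nothing to load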
