Skip to content

Commit

Permalink
Forgot the trainer ups
Browse files Browse the repository at this point in the history
  • Loading branch information
TJ-Solergibert committed Jul 16, 2024
1 parent d9f0670 commit efe8720
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ def find_stage_idx_to_resume():

def train(
self,
dataloader_or_dls: Dict[
train_dataloader_or_dls: Dict[
str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
],
valid_dataloader_or_dls: Dict[
str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
],
**kwargs,
Expand Down Expand Up @@ -424,7 +427,7 @@ def train(
prof.step()

self.iteration_start_time = time.time()
self._update_dataloader_based_on_training_stages(dataloader_or_dls)
self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)

# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
Expand Down

0 comments on commit efe8720

Please sign in to comment.