diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 412571a1e7f..c8925c4bce6 100755
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -636,10 +636,11 @@ def _train_loop(self) -> None:
                 elif self._use_closures():
                     for optimizer in state.optimizers:
                         if use_grad_scaling:
-                            total_loss = state.scaler.step(optimizer,
-                                                           closure=lambda: self._train_batch(microbatches))
+                            total_loss = state.scaler.step(
+                                optimizer, closure=lambda **kwargs: self._train_batch(microbatches, **kwargs))
                         else:
-                            total_loss = optimizer.step(closure=lambda: self._train_batch(microbatches).item())
+                            total_loss = optimizer.step(
+                                closure=lambda **kwargs: self._train_batch(microbatches, **kwargs).item())
                 else:
                     total_loss = self._train_batch(microbatches)
                     for optimizer in state.optimizers:
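
The change above makes each closure forward keyword arguments into `self._train_batch(microbatches, **kwargs)`, so an optimizer (or the GradScaler wrapping it) can pass options through to the batch-training function when it re-evaluates the loss. Below is a minimal, self-contained sketch of that pattern, not Composer's actual code: the `TwoStepOptimizer` class, the `ddp_sync` flag, and the `train_batch` helper are hypothetical names used only for illustration.

```python
import torch


class TwoStepOptimizer(torch.optim.SGD):
    """Toy optimizer that evaluates the closure twice, forwarding a keyword."""

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "this optimizer requires a closure"
        with torch.enable_grad():
            # First pass: hypothetical flag asking the closure to skip sync work.
            closure(ddp_sync=False)
            # Second pass: normal forward/backward before the real parameter update.
            loss = closure(ddp_sync=True)
        super().step()
        return loss


model = torch.nn.Linear(4, 1)
opt = TwoStepOptimizer(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def train_batch(**kwargs):
    # Stand-in for Trainer._train_batch(microbatches, **kwargs); here the
    # forwarded kwargs are accepted but simply ignored.
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


# Mirrors the diff: the closure forwards **kwargs so keywords supplied by the
# optimizer's step() reach the batch-training function.
total_loss = opt.step(closure=lambda **kwargs: train_batch(**kwargs))
print(total_loss.item())
```

With the pre-diff closures (`lambda: self._train_batch(microbatches)`), any keyword passed by such an optimizer would raise a TypeError; forwarding `**kwargs` keeps the plain `closure()` call working while also supporting optimizers that pass extra arguments.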