diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 13126ba919..8f6c3974de 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -295,7 +295,7 @@ def _get_initial_device_train_microbatch_size( "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", ) from e return batch_size - elif isinstance(device_train_microbatch_size, Union[int, float]): + elif isinstance(device_train_microbatch_size, (int, float)): return device_train_microbatch_size else: raise ValueError("device_train_microbatch_size must be an int or ``'auto'``")