diff --git a/nemo_aligner/utils/trainer_utils.py b/nemo_aligner/utils/trainer_utils.py index b6a17f0d6..524763077 100644 --- a/nemo_aligner/utils/trainer_utils.py +++ b/nemo_aligner/utils/trainer_utils.py @@ -29,7 +29,7 @@ def compute_num_steps_per_epoch( num_steps_per_epoch = sampler.total_samples // sampler.global_batch_size - if limit_train_batches is None or limit_train_batches > 1.0: + if limit_train_batches is None or (isinstance(limit_train_batches, float) and limit_train_batches > 1.0): limit_train_batches = 1.0 if limit_train_batches >= 0: