Skip to content

Commit

Permalink
update exception message
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Aug 6, 2020
1 parent ff7353e commit cf3935c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,11 @@ def __init__(
)
limit_train_batches = train_percent_check

self.limit_test_batches = _determine_limit_batches(limit_test_batches)
self.limit_val_batches = _determine_limit_batches(limit_val_batches)
self.limit_train_batches = _determine_limit_batches(limit_train_batches)
self.val_check_interval = _determine_limit_batches(val_check_interval)
self.overfit_batches = _determine_limit_batches(overfit_batches)
self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.overfit_batches)

# AMP init
Expand Down Expand Up @@ -1428,12 +1428,12 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _determine_limit_batches(batches: Union[int, float]) -> Union[int, float]:
def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches}, it has to be in (0, 1) or an int.'
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)

0 comments on commit cf3935c

Please sign in to comment.