added setting to skip x epochs when using train_and_test_on_datasets (#258)

Co-authored-by: Dref360 <fred@glowstick.cx>
Co-authored-by: Frédéric Branchaud-Charron <frederic.branchaud.charron@gmail.com>
3 people authored Jul 1, 2023
1 parent 81e248c commit def01e3
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions baal/modelwrapper.py
@@ -148,6 +148,7 @@ def train_and_test_on_datasets(
         return_best_weights=False,
         patience=None,
         min_epoch_for_es=0,
+        skip_epochs=1,
     ):
         """
         Train and test the model on both Dataset `train_dataset`, `test_dataset`.
@@ -166,6 +167,7 @@ def train_and_test_on_datasets(
             patience (Optional[int]): If provided, will use early stopping to stop after
                 `patience` epoch without improvement.
             min_epoch_for_es (int): Epoch at which the early stopping starts.
+            skip_epochs (int): Number of epochs to skip for test_on_dataset

         Returns:
             History and best weights if required.
@@ -178,17 +180,22 @@
             _ = self.train_on_dataset(
                 train_dataset, optimizer, batch_size, 1, use_cuda, workers, collate_fn, regularizer
             )
-            te_loss = self.test_on_dataset(test_dataset, batch_size, use_cuda, workers, collate_fn)
-            hist.append(self.get_metrics())
-            if te_loss < best_loss:
-                best_epoch = e
-                best_loss = te_loss
-                if return_best_weights:
-                    best_weight = deepcopy(self.state_dict())
-
-            if patience is not None and (e - best_epoch) > patience and (e > min_epoch_for_es):
-                # Early stopping
-                break
+            if e % skip_epochs == 0:
+                te_loss = self.test_on_dataset(
+                    test_dataset, batch_size, use_cuda, workers, collate_fn
+                )
+                hist.append(self.get_metrics())
+                if te_loss < best_loss:
+                    best_epoch = e
+                    best_loss = te_loss
+                    if return_best_weights:
+                        best_weight = deepcopy(self.state_dict())
+
+                if patience is not None and (e - best_epoch) > patience and (e > min_epoch_for_es):
+                    # Early stopping
+                    break
+            else:
+                hist.append(self.get_metrics("train"))

         if return_best_weights:
             return hist, best_weight
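For readers unfamiliar with the API, here is a minimal usage sketch of the new skip_epochs parameter (not part of this commit). The model, datasets, and hyperparameter values are placeholders, and the ModelWrapper(model, criterion) construction plus the keyword names are assumed from the baal 1.x API visible in the diff.

# Hypothetical usage sketch of skip_epochs (not part of this commit).
# Assumes baal 1.x's ModelWrapper(model, criterion) constructor and the
# parameter names visible in the diff; model and datasets are toy placeholders.
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset

from baal.modelwrapper import ModelWrapper

model = nn.Linear(10, 2)
wrapper = ModelWrapper(model, nn.CrossEntropyLoss())
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Tiny random datasets so the snippet is self-contained.
features = torch.randn(64, 10)
labels = torch.randint(0, 2, (64,))
train_dataset = TensorDataset(features, labels)
test_dataset = TensorDataset(features, labels)

hist, best_weights = wrapper.train_and_test_on_datasets(
    train_dataset,
    test_dataset,
    optimizer,
    batch_size=16,
    epoch=10,
    use_cuda=torch.cuda.is_available(),
    return_best_weights=True,
    patience=3,
    skip_epochs=2,  # new: run test_on_dataset only every 2nd epoch
)

With skip_epochs=2, the returned history holds test metrics for epochs 0, 2, 4, ... and train-only metrics for the skipped epochs, matching the get_metrics("train") branch added in this diff.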
