Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarang committed Jun 9, 2020
1 parent 64e646e commit e244185
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
13 changes: 11 additions & 2 deletions tests/base/model_train_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,21 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):

# calculate loss
loss_val = self.loss(y, y_hat)
loss_scalar = loss_val.item()

# alternate possible outputs to test
if batch_idx % 2 == 0:
output = OrderedDict({
'loss': loss_val,
'progress_bar': {'some_val': loss_val * loss_val},
'log': {'train_some_val': loss_val * loss_val},
})

# return scalars for "log" and "progress_bar"
output = OrderedDict({
'loss': loss_val,
'progress_bar': {'some_val': loss_val * loss_val},
'log': {'train_some_val': loss_val * loss_val},
'progress_bar': {'some_val': loss_scalar * loss_scalar},
'log': {'train_some_val': loss_scalar * loss_scalar},
})
return output

Expand Down
7 changes: 6 additions & 1 deletion tests/base/model_valid_epoch_ends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class ValidationEpochEndVariations(ABC):
"""
Houses all variations of validation_epoch_end steps
"""

def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
Expand All @@ -23,7 +24,11 @@ def _mean(res, key):
val_loss_mean = _mean(outputs, 'val_loss')
val_acc_mean = _mean(outputs, 'val_acc')

metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
# alternate between tensor and scalar
if self.current_epoch % 2:
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
else:
metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results

Expand Down

0 comments on commit e244185

Please sign in to comment.