fix result obj dp auto reduce #3013

Merged
13 commits merged on Aug 17, 2020
8 changes: 8 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -357,6 +357,14 @@ def reduce_across_time(cls, time_outputs):
result['meta'] = meta
return result

def dp_reduce(self):
for k, value in self.items():
if k == 'meta':
continue
if isinstance(value, list):
value = torch.tensor(value)
self[k] = value.mean(dim=-1)

@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']
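For readers of the hunk above: under dp/ddp2 each GPU contributes one entry per logged value, so the new `dp_reduce` averages every non-meta entry over that trailing device dimension. A minimal standalone sketch of the same behaviour (an illustration, not the library class):

import torch

# Mirrors the method above: values arrive either as tensors whose last
# dimension indexes devices, or as plain Python lists of per-device numbers.
def dp_reduce(result: dict) -> None:
    for k, value in result.items():
        if k == 'meta':  # bookkeeping entry, never reduced
            continue
        if isinstance(value, list):
            value = torch.tensor(value)
        result[k] = value.mean(dim=-1)  # average across the device dimension

partial = {'loss': torch.tensor([0.9, 1.1]), 'meta': {}}  # loss from 2 GPUs
dp_reduce(partial)
print(partial['loss'])  # tensor(1.) -- a single value for the rest of the loop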
23 changes: 13 additions & 10 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -343,17 +343,20 @@ def _evaluate(
m = 'only EvalResults or dicts are allowed from validation_step'
raise MisconfigurationException(m)

# ------------------
# EVAL STEP END
# ------------------
# on dp / ddp2 might still want to do something with the batch parts
if test_mode:
if self.is_overridden('test_step_end'):
model_ref = self.get_model()
with self.profiler.profile('test_step_end'):
output = model_ref.test_step_end(output)
else:
if self.is_overridden('validation_step_end'):
model_ref = self.get_model()
with self.profiler.profile('validation_step_end'):
output = model_ref.validation_step_end(output)
eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
if self.is_overridden(eval_step_end_hook_name):
model_ref = self.get_model()
with self.profiler.profile(eval_step_end_hook_name):
eval_step_end = getattr(model_ref, eval_step_end_hook_name)
output = eval_step_end(output)

elif is_result_obj and (self.use_dp or self.use_ddp2):
# result auto reduce
output.dp_reduce()

# callbacks (on __batch_end)
if test_mode:
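The change above collapses the duplicated test/validation branches into a single getattr dispatch and, when no step-end hook is overridden, falls back to `dp_reduce` for Result objects. A small self-contained illustration of the dispatch pattern (class and hook bodies are made up for the example):

# Pick the hook name once, resolve it with getattr, and drop the duplicated
# test/validation branches.
class DummyModule:
    def test_step_end(self, output):
        return {'stage': 'test', **output}

    def validation_step_end(self, output):
        return {'stage': 'val', **output}

def run_eval_step_end(model, output, test_mode):
    hook_name = 'test_step_end' if test_mode else 'validation_step_end'
    hook = getattr(model, hook_name)
    return hook(output)

print(run_eval_step_end(DummyModule(), {'loss': 0.5}, test_mode=False))
# -> {'stage': 'val', 'loss': 0.5}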
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -1221,6 +1221,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
else:
output = self.model.training_step(*args)

is_result_obj = isinstance(output, Result)

# allow any mode to define training_step_end
# do something with all the dp outputs (like softmax)
if self.is_overridden('training_step_end'):
@@ -1229,6 +1231,9 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
# TODO: modify when using result obj
output = model_ref.training_step_end(output)

elif is_result_obj and (self.use_dp or self.use_ddp2):
output.dp_reduce()

# allow any mode to define training_end
# TODO: remove in 1.0.0
if self.is_overridden('training_end'):
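As in the eval path, the training path only reduces automatically when the step returned a Result object, `training_step_end` is not overridden, and the backend is dp or ddp2. A hedged sketch of that predicate with the trainer state passed in explicitly (function and parameter names are illustrative, not the trainer's API):

def auto_dp_reduce_needed(is_result_obj, training_step_end_overridden,
                          use_dp, use_ddp2):
    # a user-defined training_step_end takes precedence; otherwise Result
    # objects returned under dp/ddp2 are reduced automatically
    if training_step_end_overridden:
        return False
    return is_result_obj and (use_dp or use_ddp2)

assert auto_dp_reduce_needed(True, False, use_dp=True, use_ddp2=False)
assert not auto_dp_reduce_needed(False, False, use_dp=True, use_ddp2=False)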
22 changes: 22 additions & 0 deletions tests/base/model_train_steps.py
@@ -79,6 +79,28 @@ def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=
self.training_step_called = True
return result

def training_step_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
# forward pass
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x.to(self.device))

# calculate loss
loss_val = self.loss(y.to(y_hat.device), y_hat)
log_val = loss_val

# alternate between tensors and scalars for "log" and "progress_bar"
if batch_idx % 2 == 0:
log_val = log_val.item()

result = TrainResult(loss_val)
result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
result.log('train_some_val', log_val * log_val)

self.training_step_called = True

return result

def training_step_end_full_loop_result_obj_dp(self, result):
"""
Full loop flow train step (result obj + dp)
22 changes: 22 additions & 0 deletions tests/base/model_valid_steps.py
@@ -52,6 +52,28 @@ def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
})
return result

def validation_step_result_obj_dp(self, batch, batch_idx, *args, **kwargs):
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x.to(self.device))

y = y.to(y_hat.device)
loss_val = self.loss(y, y_hat)

# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc).type_as(x)

result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
result.log_dict({
'val_loss': loss_val,
'val_acc': val_acc,
})

self.validation_step_called = True
return result

def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
"""
Lightning calls this inside the validation loop
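As a quick standalone check of the accuracy computation in the validation step above, the sum-then-divide form is equivalent to taking the mean of the comparison mask (tensor values below are made up):

import torch

y = torch.tensor([0, 1, 2, 1])           # made-up targets
labels_hat = torch.tensor([0, 1, 1, 1])   # made-up argmax predictions

val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
assert val_acc == (y == labels_hat).float().mean().item() == 0.75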
35 changes: 35 additions & 0 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -534,6 +534,41 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
assert 'epoch_train_epoch_end_metric' in seen_keys


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_loop_steps_only_dp(tmpdir):
os.environ['PL_DEV_DEBUG'] = '1'

batches = 10
epochs = 3

model = EvalModelTemplate()
model.validation_step = None
model.test_step = None
model.training_step = model.training_step_result_obj_dp
model.training_step_end = None
model.training_epoch_end = None
model.validation_step = model.validation_step_result_obj_dp
model.validation_step_end = None
model.validation_epoch_end = None
model.test_dataloader = None

trainer = Trainer(
default_root_dir=tmpdir,
distributed_backend='dp',
gpus=[0, 1],
max_epochs=epochs,
early_stop_callback=True,
row_log_interval=2,
limit_train_batches=batches,
weights_summary=None,
)

trainer.fit(model)

assert model.training_step_called
assert model.validation_step_called


def test_result_map(tmpdir):
result = TrainResult()
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})