Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[warning] Add warning when values are not being reduced #6417

Merged
merged 14 commits into from
Mar 26, 2021
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


## [1.2.3] - 2021-03-09

### Fixed
Expand All @@ -152,8 +155,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Resolve memory leak for evaluation ([#6326](https://github.com/PyTorchLightning/pytorch-lightning/pull/6326)
- Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330)
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))


Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,12 @@ def rename_keys(self, map_dict: dict):
meta[dest] = meta[source]
del meta[source]

def get_non_metrics_keys(self):
"""
This function is used to filter metric keys for which the value isn't a Metric
"""
return [k for k, v in self.items() if not isinstance(v, Metric)]


def choose_last(x):
if isinstance(x, (torch.Tensor, list)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import proxy

import torch

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import DistributedType, LightningEnum
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)


class MetricWarningCache(WarningCache):

def __init__(self):
super().__init__()
self.warned_metrics = []


warning_cache = MetricWarningCache()


class ResultStoreType(LightningEnum):
Expand Down Expand Up @@ -50,8 +66,9 @@ class HookResultStore:
Those data structures enables us to reduce properly Result object when batch loop is finished.
"""

def __init__(self, fx_name):
def __init__(self, fx_name: str, all_gather_fn: Callable) -> None:
self._fx_name = fx_name
self._all_gather_fn = all_gather_fn
self._internals = {}
self._internals_reduced = {}
self._internal_type = None
Expand Down Expand Up @@ -104,8 +121,26 @@ def get_batch_log_metrics(self, *args, **kwargs):
def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None:
if not isinstance(opt_metric, Result):
raise Exception("The provided opt_metric should be a Result Object. Something is wrong")

func = getattr(opt_metric, func_name)
metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs)
if (
torch.distributed.is_available() and torch.distributed.is_initialized()
and self._all_gather_fn.__self__.trainer.world_size > 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really the easiest way to check the world size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, but I had some issues with windows using get_world_size utils from torch. Any idea ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try to find a better way.

):
for non_metric_key in opt_metric.get_non_metrics_keys():
if non_metric_key in metrics_to_log and non_metric_key not in warning_cache.warned_metrics:
metric = self._all_gather_fn(metrics_to_log[non_metric_key])
if any(metric[0] != m for m in metric[1:]):
warning_cache.warn(
f"The value associated to the key {non_metric_key}: {metric.cpu().tolist()} "
"doesn't appear to be the same accross all processes. "
"HINT: One could either do: `self.log(..., sync_dist=True, sync_fn=torch.mean)`"
" to force mean reduction across processes which can be inaccurate or implement"
carmocca marked this conversation as resolved.
Show resolved Hide resolved
" a `pytorch_lightning.metrics.Metric`"
tchaton marked this conversation as resolved.
Show resolved Hide resolved
)
warning_cache.warned_metrics.append(non_metric_key)

results.append(metrics_to_log)

def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
Expand Down Expand Up @@ -222,8 +257,8 @@ class EpochResultStore:
```
"""

def __init__(self, trainer) -> None:
self.trainer = trainer
def __init__(self, trainer: 'pl.Trainer') -> None:
self.trainer = proxy(trainer)
self.reset()

def __getitem__(self, key: str) -> Any:
Expand Down Expand Up @@ -275,7 +310,8 @@ def cache_result(self) -> None:
info = self.info
fx_name = info["fx_name"]

self._internals.setdefault(fx_name, HookResultStore(fx_name))
all_gather_fn = self.trainer.lightning_module.all_gather
self._internals.setdefault(fx_name, HookResultStore(fx_name, all_gather_fn))

# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)
Expand Down
7 changes: 6 additions & 1 deletion tests/trainer/logging_/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ class TestLoggingSyncDistModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
self.log('cho', acc, on_step=False, on_epoch=True)
return acc

def validation_step(self, batch, batch_idx):
Expand All @@ -763,8 +764,12 @@ def validation_step(self, batch, batch_idx):
gpus=2,
profiler="pytorch"
)
trainer.fit(model)

if os.getenv("LOCAL_RANK") == '0':
with pytest.warns(UserWarning, match="The value associated to the key cho:"):
trainer.fit(model)
else:
trainer.fit(model)
assert trainer.logged_metrics['foo'] == 2
assert trainer.logged_metrics['bar'] == 2

Expand Down