Actually show deprecation warnings and their line level [2/2] (#8002)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
awaelchli authored and lexierule committed Jun 22, 2021
1 parent 662b6d6 commit 55271b1
Showing 19 changed files with 154 additions and 74 deletions.
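For orientation before the per-file diffs: the commit relies on two standard-library mechanisms, a dedicated `DeprecationWarning` subclass re-enabled via `warnings.simplefilter`, and the `stacklevel` argument so the reported file and line belong to the caller rather than to the warning helper. Below is a minimal standalone sketch of that pattern; names such as `deprecation` and `deprecated_api` are illustrative, not Lightning APIs.

```python
# Standalone sketch (not Lightning's actual code) of the two pieces this commit
# combines: a DeprecationWarning subclass that is switched back on, and
# `stacklevel` so the reported file/line is the caller's, not the helper's.
import warnings


class LightningDeprecationWarning(DeprecationWarning):
    ...


# Python hides DeprecationWarning by default outside __main__; opting the
# subclass back in makes deprecation messages visible to users again.
warnings.simplefilter("default", LightningDeprecationWarning)


def deprecation(message: str, stacklevel: int = 3) -> None:
    # stacklevel=3 skips this helper and the deprecated function itself,
    # attributing the warning to the user code that called the deprecated API.
    warnings.warn(message, category=LightningDeprecationWarning, stacklevel=stacklevel)


def deprecated_api() -> None:  # illustrative stand-in for a deprecated hook
    deprecation("`deprecated_api` is deprecated and will be removed in a future release")


if __name__ == "__main__":
    deprecated_api()  # the emitted warning points at this line
```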
5 changes: 1 addition & 4 deletions .azure-pipelines/ipu-tests.yml
@@ -53,12 +53,9 @@ jobs:
export GIT_TERMINAL_PROMPT=1
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
python ./requirements/adjust_versions.py requirements/extra.txt
python ./requirements/adjust_versions.py requirements/examples.txt
pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed
pip install . --requirement requirements/devel.txt
pip list
displayName: 'Install dependencies'
1 change: 1 addition & 0 deletions pytorch_lightning/core/datamodule.py
@@ -20,6 +20,7 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types


2 changes: 1 addition & 1 deletion pytorch_lightning/core/grads.py
@@ -18,7 +18,7 @@

from torch.nn import Module

from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.grads import grad_norm as new_grad_norm


3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/csv_logs.py
@@ -29,7 +29,8 @@

from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

log = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
@@ -20,8 +20,8 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import _module_available, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

_TESTTUBE_AVAILABLE = _module_available("test_tube")

8 changes: 6 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -28,10 +28,14 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities import (
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_8:
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -33,6 +33,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning

if _DEEPSPEED_AVAILABLE:
import deepspeed
4 changes: 2 additions & 2 deletions pytorch_lightning/profiler/pytorch.py
@@ -24,7 +24,7 @@
from torch.autograd.profiler import record_function

from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE

@@ -351,7 +351,7 @@ def __deprecation_check(
if profiled_functions is not None:
rank_zero_deprecation(
"`PyTorchProfiler.profiled_functions` has been renamed to"
" `record_functions` in v1.3 and will be removed in v1.5",
" `record_functions` in v1.3 and will be removed in v1.5"
)
if not record_functions:
record_functions |= set(profiled_functions)
(additional changed file; file name not captured in this view)
@@ -60,8 +60,10 @@
device_parser,
DeviceType,
DistributedType,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _HOROVOD_AVAILABLE:
44 changes: 21 additions & 23 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -113,30 +113,28 @@ def attach_dataloaders(
def attach_datamodule(
self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# We use datamodule if it's been provided, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:

# Override loader hooks
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer

# experimental feature for Flash
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline
if datamodule is None:
return

# Override loader hooks
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer

# experimental feature for Flash
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline


class _PatchDataLoader(object):
(additional changed file; file name not captured in this view)
@@ -14,8 +14,7 @@
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
@@ -14,7 +14,7 @@

from abc import ABC

from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.metrics import metrics_to_scalars as new_metrics_to_scalars


2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/model_hooks.py
@@ -17,6 +17,8 @@
from typing import Optional

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


class TrainerModelHooksMixin(ABC):
9 changes: 2 additions & 7 deletions pytorch_lightning/utilities/__init__.py
@@ -16,13 +16,7 @@
import numpy

from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import ( # noqa: F401
AllGatherGrad,
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
AMPType,
DeviceType,
@@ -59,6 +53,7 @@
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
29 changes: 17 additions & 12 deletions pytorch_lightning/utilities/distributed.py
@@ -14,8 +14,8 @@

import logging
import os
import warnings
from functools import partial, wraps
from functools import wraps
from platform import python_version
from typing import Any, Optional, Union

import torch
@@ -65,22 +65,26 @@ def _get_rank() -> int:
rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank())


def _warn(*args, **kwargs):
warnings.warn(*args, **kwargs)


def _info(*args, **kwargs):
def _info(*args, stacklevel: int = 2, **kwargs):
if python_version() >= "3.8.0":
kwargs['stacklevel'] = stacklevel
log.info(*args, **kwargs)


def _debug(*args, **kwargs):
def _debug(*args, stacklevel: int = 2, **kwargs):
if python_version() >= "3.8.0":
kwargs['stacklevel'] = stacklevel
log.debug(*args, **kwargs)


rank_zero_debug = rank_zero_only(_debug)
rank_zero_info = rank_zero_only(_info)
rank_zero_warn = rank_zero_only(_warn)
rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning)
@rank_zero_only
def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
_debug(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
_info(*args, stacklevel=stacklevel, **kwargs)


def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
@@ -294,6 +298,7 @@ def register_ddp_comm_hook(
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
"""
from pytorch_lightning.utilities import rank_zero_warn
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
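A note on the `stacklevel` defaults above: with the `rank_zero_only` decorator in between, there are roughly four frames between the user's call and the underlying `log.info`/`log.debug` call, which is presumably why the wrappers default to `stacklevel=4` while the bare `_info`/`_debug` helpers default to 2. A rough standalone sketch of the idea (a simplification; the real decorator also checks the process rank):

```python
import logging
import sys
from functools import wraps

logging.basicConfig(format="%(filename)s:%(lineno)d %(message)s", level=logging.INFO)
log = logging.getLogger("demo")


def rank_zero_only(fn):
    # Simplified stand-in: the real decorator only runs `fn` on global rank 0.
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapped_fn


def _info(*args, stacklevel: int = 2, **kwargs):
    # `stacklevel` for logging calls requires Python >= 3.8, hence the version
    # guard in the diff above (sketched here with sys.version_info).
    if sys.version_info >= (3, 8):
        kwargs["stacklevel"] = stacklevel
    log.info(*args, **kwargs)


@rank_zero_only
def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
    _info(*args, stacklevel=stacklevel, **kwargs)


rank_zero_info("hello")  # the reported filename:lineno point at this call site,
                         # not at the `log.info` call inside `_info`
```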
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/parsing.py
@@ -18,7 +18,7 @@
from argparse import Namespace
from typing import Any, Dict, Optional, Sequence, Tuple, Union

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.warnings import rank_zero_warn


def str_to_bool_or_str(val: str) -> Union[str, bool]:
@@ -97,7 +97,7 @@ def clean_namespace(hparams):
del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

for k in del_attrs:
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
del hparams_dict[k]


45 changes: 31 additions & 14 deletions pytorch_lightning/utilities/warnings.py
@@ -11,23 +11,40 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
"""Warning-related utilities"""
import warnings
from functools import partial

from pytorch_lightning.utilities.distributed import rank_zero_only

class WarningCache:

def __init__(self):
self.warnings = set()
def _warn(*args, stacklevel: int = 2, **kwargs):
warnings.warn(*args, stacklevel=stacklevel, **kwargs)

def clear(self):
self.warnings.clear()

def warn(self, m, *args, **kwargs):
if m not in self.warnings:
self.warnings.add(m)
rank_zero_warn(m, *args, **kwargs)
@rank_zero_only
def rank_zero_warn(*args, stacklevel: int = 4, **kwargs):
_warn(*args, stacklevel=stacklevel, **kwargs)

def deprecation(self, m, *args, **kwargs):
if m not in self.warnings:
self.warnings.add(m)
rank_zero_deprecation(m, *args, **kwargs)

class LightningDeprecationWarning(DeprecationWarning):
...


# enable our warnings
warnings.simplefilter('default', LightningDeprecationWarning)

rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)


class WarningCache(set):

def warn(self, m, *args, stacklevel: int = 5, **kwargs):
if m not in self:
self.add(m)
rank_zero_warn(m, *args, stacklevel=stacklevel, **kwargs)

def deprecation(self, m, *args, stacklevel: int = 5, **kwargs):
if m not in self:
self.add(m)
rank_zero_deprecation(m, *args, stacklevel=stacklevel, **kwargs)
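For illustration, the set-based `WarningCache` above deduplicates by message, so a warning raised inside a loop is emitted once per cache instance. A standalone sketch of that behaviour, using `warnings.warn` directly instead of the rank-zero helpers (which is why the default `stacklevel` here is 2 rather than 5):

```python
import warnings


class WarningCache(set):
    """Simplified copy of the pattern above, without the rank-zero plumbing."""

    def warn(self, m, *args, stacklevel: int = 2, **kwargs):
        if m not in self:
            self.add(m)
            warnings.warn(m, *args, stacklevel=stacklevel, **kwargs)


cache = WarningCache()
for _ in range(3):
    # Emitted only once, and attributed to this line via `stacklevel=2`.
    cache.warn("this loop produces the same message on every iteration")
```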
6 changes: 6 additions & 0 deletions tests/special_tests.sh
@@ -70,6 +70,12 @@ done

nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx

# needs to run outside of `pytest`
python tests/utilities/test_warnings.py
if [ $? -eq 0 ]; then
report+="Ran\ttests/utilities/test_warnings.py\n"
fi

# echo test report
printf '=%.s' {1..80}
printf "\n$report"
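The new lines run `tests/utilities/test_warnings.py` as a plain Python process, presumably because pytest's warning capture would interfere with checking how warnings are rendered and attributed. A hypothetical standalone check along the same lines (not the actual test file) could look like:

```python
# Hypothetical standalone check, run as a script: assert that a warning raised
# through a helper is attributed to the caller's file, which is the behaviour
# this commit is about.
import warnings


def emit_deprecation() -> None:
    warnings.warn("illustrative deprecation", DeprecationWarning, stacklevel=2)


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        emit_deprecation()
    assert caught[0].filename == __file__, caught[0].filename
    print("warning attributed to the caller as expected")
```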