Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually show deprecation warnings and their line level [2/2] #8002

Merged
merged 20 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.distributed import rank_zero_deprecation


class LightningDataModule(CheckpointHooks, DataHooks):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from torch.nn import Module

from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.grads import grad_norm as new_grad_norm


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def datamodule(self) -> Any:
warning_cache.deprecation(
"The `LightningModule.datamodule` property is deprecated in v1.3 and will be removed in v1.5."
" Access the datamodule through using `self.trainer.datamodule` instead.",
stacklevel=5,
stacklevel=6,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)
return self._datamodule

Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

log = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import _module_available, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only

_TESTTUBE_AVAILABLE = _module_available("test_tube")

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Te
self.warning_cache.deprecation(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5",
" the old signature will be removed in v1.5"
)
args.append(opt_idx)
elif not self.trainer.has_arg(
Expand Down Expand Up @@ -685,7 +685,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio
self.warning_cache.deprecation(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5",
" the old signature will be removed in v1.5"
)
step_kwargs['optimizer_idx'] = opt_idx
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
Expand Down
13 changes: 6 additions & 7 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities import (
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_deprecation,
rank_zero_only,
rank_zero_warn,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_8:
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _warn, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning

if _DEEPSPEED_AVAILABLE:
import deepspeed
Expand Down Expand Up @@ -263,7 +264,7 @@ def __init__(
"The usage of `cpu_offload`, `cpu_offload_params`, and `cpu_offload_use_pin_memory` "
"is deprecated since v1.4 and will be removed in v1.5."
" From now on use `offload_optimizer`, `offload_parameters` and `pin_memory`.",
category=DeprecationWarning
category=LightningDeprecationWarning
)
offload_optimizer = cpu_offload
offload_parameters = cpu_offload_params
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation

rank_zero_deprecation(
"Using ``import pytorch_lightning.profiler.profilers`` is depreceated in v1.4, and will be removed in v1.6. "
"Using ``import pytorch_lightning.profiler.profilers`` is deprecated in v1.4, and will be removed in v1.6. "
"HINT: Use ``import pytorch_lightning.profiler`` directly."
)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/profiler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch.autograd.profiler import record_function

from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE

Expand Down Expand Up @@ -351,7 +351,7 @@ def __deprecation_check(
if profiled_functions is not None:
rank_zero_deprecation(
"`PyTorchProfiler.profiled_functions` has been renamed to"
" `record_functions` in v1.3 and will be removed in v1.5",
" `record_functions` in v1.3 and will be removed in v1.5"
)
if not record_functions:
record_functions |= set(profiled_functions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@
device_parser,
DeviceType,
DistributedType,
rank_zero_deprecation,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _HOROVOD_AVAILABLE:
Expand Down
44 changes: 21 additions & 23 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,28 @@ def attach_dataloaders(
def attach_datamodule(
self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None
) -> None:
# We use datamodule if it's been provided, otherwise we check model for it
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
datamodule = datamodule or getattr(model, 'datamodule', None)

# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:

# Override loader hooks
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer

# experimental feature for Flash
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline
if datamodule is None:
return

# Override loader hooks
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer

# experimental feature for Flash
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline


class _PatchDataLoader:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,10 @@ def extra(self, extra: Mapping[str, Any]) -> None:

def check_fn(v):
if v.grad_fn is not None:
warning_cache.warn(
warning_cache.deprecation(
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning
" `return {'loss': ..., 'something': something.detach()}`"
)
return v.detach()
return v
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from abc import ABC

from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.metrics import metrics_to_scalars as new_metrics_to_scalars


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.distributed import rank_zero_deprecation
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@
import numpy

from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import ( # noqa: F401
AllGatherGrad,
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
AMPType,
DeviceType,
Expand Down Expand Up @@ -63,6 +57,7 @@
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
Expand Down
29 changes: 17 additions & 12 deletions pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import logging
import os
import warnings
from functools import partial, wraps
from functools import wraps
from platform import python_version
from typing import Any, Optional, Union

import torch
Expand Down Expand Up @@ -65,22 +65,26 @@ def _get_rank() -> int:
rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank())


def _warn(*args, **kwargs):
warnings.warn(*args, **kwargs)


def _info(*args, **kwargs):
def _info(*args, stacklevel: int = 2, **kwargs):
if python_version() >= "3.8.0":
carmocca marked this conversation as resolved.
Show resolved Hide resolved
kwargs['stacklevel'] = stacklevel
log.info(*args, **kwargs)


def _debug(*args, **kwargs):
def _debug(*args, stacklevel: int = 2, **kwargs):
if python_version() >= "3.8.0":
kwargs['stacklevel'] = stacklevel
log.debug(*args, **kwargs)


rank_zero_debug = rank_zero_only(_debug)
rank_zero_info = rank_zero_only(_info)
rank_zero_warn = rank_zero_only(_warn)
rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning)
@rank_zero_only
def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
_debug(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
_info(*args, stacklevel=stacklevel, **kwargs)


def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
Expand Down Expand Up @@ -294,6 +298,7 @@ def register_ddp_comm_hook(
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
"""
from pytorch_lightning.utilities import rank_zero_warn
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dataclasses import fields, is_dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.warnings import rank_zero_warn


def str_to_bool_or_str(val: str) -> Union[str, bool]:
Expand Down Expand Up @@ -98,7 +98,7 @@ def clean_namespace(hparams):
del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

for k in del_attrs:
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
del hparams_dict[k]


Expand Down
32 changes: 27 additions & 5 deletions pytorch_lightning/utilities/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
import warnings
from functools import partial

from pytorch_lightning.utilities.distributed import rank_zero_only


def _warn(*args, stacklevel: int = 2, **kwargs):
warnings.warn(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_warn(*args, stacklevel: int = 4, **kwargs):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
_warn(*args, stacklevel=stacklevel, **kwargs)


class LightningDeprecationWarning(DeprecationWarning):
...


# enable our warnings
warnings.simplefilter('default', LightningDeprecationWarning)

rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)


class WarningCache(set):

def warn(self, m, *args, **kwargs):
def warn(self, m, *args, stacklevel: int = 5, **kwargs):
if m not in self:
self.add(m)
rank_zero_warn(m, *args, **kwargs)
rank_zero_warn(m, *args, stacklevel=stacklevel, **kwargs)

def deprecation(self, m, *args, **kwargs):
def deprecation(self, m, *args, stacklevel: int = 5, **kwargs):
if m not in self:
self.add(m)
rank_zero_deprecation(m, *args, **kwargs)
rank_zero_deprecation(m, *args, stacklevel=stacklevel, **kwargs)
6 changes: 6 additions & 0 deletions tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ if nvcc --version; then
nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx
fi

# needs to run outside of `pytest`
python tests/utilities/test_warnings.py
if [ $? -eq 0 ]; then
report+="Ran\ttests/utilities/test_warnings.py\n"
fi

# echo test report
printf '=%.s' {1..80}
printf "\n$report"
Expand Down
Loading