diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 177aff1faaf03f..d4871ff2158b60 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -20,8 +20,17 @@ def is_overriden(self, method_name: str, model: LightningModule = None) -> bool: # in case of calling deprecated method return False - # when code pointers are different, it was overriden - is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__ + instance_attr = getattr(model, method_name) + super_attr = getattr(super_object, method_name) + + # when code pointers are different, it was implemented + if hasattr(instance_attr, 'patch_loader_code'): + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overriden = instance_attr.__code__ is not super_attr.__code__ return is_overriden def has_arg(self, f_name, arg_name): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 290f2acf403f8b..e346f1f499700b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -971,8 +971,10 @@ class _PatchDataLoader(object): def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader - # Assign __code__, needed for checking if method has been overriden - self.__code__ = self.__call__.__code__ + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + self.patch_loader_code = str(self.__call__.__code__) def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader