Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quick patch __code__ #1352

Merged
merged 14 commits into from
Apr 3, 2020
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,17 @@ def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
# in case of calling deprecated method
return False

# when code pointers are different, it was overriden
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
instance_attr = getattr(model, method_name)
super_attr = getattr(super_object, method_name)

# when code pointers are different, it was implemented
if hasattr(instance_attr, 'patch_loader_code'):
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overriden = instance_attr.__code__ is not super_attr.__code__
return is_overriden

def has_arg(self, f_name, arg_name):
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,8 +963,10 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

# Assign __code__, needed for checking if method has been overriden
self.__code__ = self.__call__.__code__
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader