Skip to content

Commit

Permalink
quick patch __code__ (#1352)
Browse files Browse the repository at this point in the history
* quick patch

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix
  • Loading branch information
williamFalcon committed Apr 3, 2020
1 parent 1576ad9 commit 2eca8a9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,17 @@ def is_overriden(self, method_name: str, model: LightningModule = None) -> bool:
# in case of calling deprecated method
return False

# when code pointers are different, it was overriden
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
instance_attr = getattr(model, method_name)
super_attr = getattr(super_object, method_name)

# when code pointers are different, it was implemented
if hasattr(instance_attr, 'patch_loader_code'):
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overriden = instance_attr.__code__ is not super_attr.__code__
return is_overriden

def has_arg(self, f_name, arg_name):
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,8 +970,10 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

# Assign __code__, needed for checking if method has been overriden
self.__code__ = self.__call__.__code__
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader

0 comments on commit 2eca8a9

Please sign in to comment.