Skip to content

Commit

Permalink
🐛 fix add_argparse_args
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Jul 23, 2020
1 parent 0f10c54 commit 4c4b734
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@


class _DataModuleWrapper(type):

def __call__(cls, *args, **kwargs):
"""A wrapper for LightningDataModule that:
1. Runs user defined subclass's __init__
2. Assures prepare_data() runs on rank 0
3. Runs prepare_data()
4. Runs setup()
"""

# Get instance of LightningDataModule by mocking its __init__ via __call__
Expand All @@ -23,9 +22,6 @@ def __call__(cls, *args, **kwargs):
# Wrap instance's prepare_data function with rank_zero_only and reassign to instance
obj.prepare_data = rank_zero_only(obj.prepare_data)

# Run both prepare_data() and setup() post-init
obj.prepare_data()
obj.setup()
return obj


Expand Down Expand Up @@ -278,7 +274,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
List with tuples of 3 values:
(argument name, set with argument types, argument default value).
"""
datamodule_default_params = inspect.signature(cls).parameters
datamodule_default_params = inspect.signature(cls.__init__).parameters
name_type_default = []
for arg in datamodule_default_params:
arg_type = datamodule_default_params[arg].annotation
Expand Down

0 comments on commit 4c4b734

Please sign in to comment.