From dbbeb96c6b956c49f68678daf9bef915747a39f1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 25 Apr 2020 23:40:44 -0400 Subject: [PATCH 01/13] fixed distutil parsing --- pytorch_lightning/trainer/trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c067a6dc7f834..064b96701d641 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,4 +1,3 @@ -import distutils import inspect import os from argparse import ArgumentParser @@ -609,7 +608,13 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: for allowed_type in (at for at in allowed_types if at in arg_types): if allowed_type is bool: def allowed_type(x): - return bool(distutils.util.strtobool(x)) + # distutils.util.strtobool() has issues + if x.lower() == 'true': + return True + elif x.lower() == 'false': + return False + else: + raise Exception('bool not specified') if arg == 'gpus': def allowed_type(x): From 3c09f3166440aec6856ddfa2f3f51799400ee5fa Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 25 Apr 2020 23:41:14 -0400 Subject: [PATCH 02/13] fixed distutil parsing --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 064b96701d641..89b0078041232 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -598,7 +598,8 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: """ parser = ArgumentParser(parents=[parent_parser], add_help=False, ) - depr_arg_names = cls.get_deprecated_arg_names() + blacklist = ['kwargs'] + depr_arg_names = cls.get_deprecated_arg_names() + blacklist allowed_types = (str, float, int, bool) # TODO: get "help" from docstring :) From 1db9884bd7d0e6282f1b6a1b4cc53fb8638f49ab Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 26 Apr 2020 10:04:35 +0200 Subject: [PATCH 03/13] Apply suggestions from code review --- pytorch_lightning/trainer/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 89b0078041232..07e7ba41331a2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -610,11 +610,10 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: if allowed_type is bool: def allowed_type(x): # distutils.util.strtobool() has issues - if x.lower() == 'true': - return True - elif x.lower() == 'false': - return False - else: + try: + # convert any true/TRUE/false/FALSE + return eval(x.lower().capitalize()) + except NameError: raise Exception('bool not specified') if arg == 'gpus': From 3690e32d43c4178d34ce6605b49162a146286472 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Sun, 26 Apr 2020 10:06:07 +0200 Subject: [PATCH 04/13] log --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f5cebf89aae2..61eeb4352b5e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixes automatic parser bug ([#1585](https://github.com/PyTorchLightning/pytorch-lightning/issues/1585)) +- Fixed bool conversion from string ([#1606](https://github.com/PyTorchLightning/pytorch-lightning/issues/1606)) + ## [0.7.3] - 2020-04-09 ### Added From e70b03c1c6d657dfb5095eee2477a67e94cbbb31 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 06:54:31 -0400 Subject: [PATCH 05/13] fixed distutil parsing --- pytorch_lightning/trainer/trainer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 89b0078041232..19130f7289858 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -32,6 +32,7 @@ from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import parsing_utils try: @@ -602,6 +603,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: depr_arg_names = cls.get_deprecated_arg_names() + blacklist allowed_types = (str, float, int, bool) + # TODO: get "help" from docstring :) for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names): @@ -609,13 +611,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: for allowed_type in (at for at in allowed_types if at in arg_types): if allowed_type is bool: def allowed_type(x): - # distutils.util.strtobool() has issues - if x.lower() == 'true': - return True - elif x.lower() == 'false': - return False - else: - raise Exception('bool not specified') + return bool(parsing_utils.strtobool(x)) if arg == 'gpus': def allowed_type(x): From fe68a3b63244fa862f746f03dbcb2fa3ec3b4159 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 06:55:34 -0400 Subject: [PATCH 06/13] fixed distutil parsing --- pytorch_lightning/utilities/parsing_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 pytorch_lightning/utilities/parsing_utils.py diff --git a/pytorch_lightning/utilities/parsing_utils.py b/pytorch_lightning/utilities/parsing_utils.py new file mode 100644 index 0000000000000..f3175860c7d7e --- /dev/null +++ b/pytorch_lightning/utilities/parsing_utils.py @@ -0,0 +1,14 @@ +def strtobool (val): + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. 
+ """ + val = val.lower() + if val in ('y', 'yes', 't', 'true', 'on', '1'): + return 1 + elif val in ('n', 'no', 'f', 'false', 'off', '0'): + return 0 + else: + raise ValueError("invalid truth value %r" % (val,)) From 8928d47de481a68210d1174cac411c38056b3fd6 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 07:01:34 -0400 Subject: [PATCH 07/13] fixed distutil parsing --- pytorch_lightning/utilities/parsing_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/parsing_utils.py b/pytorch_lightning/utilities/parsing_utils.py index f3175860c7d7e..3f47e54c62617 100644 --- a/pytorch_lightning/utilities/parsing_utils.py +++ b/pytorch_lightning/utilities/parsing_utils.py @@ -1,5 +1,6 @@ -def strtobool (val): +def strtobool(val): """Convert a string representation of truth to true (1) or false (0). + Copied from python True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if @@ -11,4 +12,4 @@ def strtobool (val): elif val in ('n', 'no', 'f', 'false', 'off', '0'): return 0 else: - raise ValueError("invalid truth value %r" % (val,)) + raise ValueError(f'invalid truth value {val}') From cfac4d26ba0ffc772e2ca65b03759a5cc67b5003 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 07:01:58 -0400 Subject: [PATCH 08/13] fixed distutil parsing --- pytorch_lightning/utilities/parsing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/parsing_utils.py b/pytorch_lightning/utilities/parsing_utils.py index 3f47e54c62617..7de79710198c8 100644 --- a/pytorch_lightning/utilities/parsing_utils.py +++ b/pytorch_lightning/utilities/parsing_utils.py @@ -1,6 +1,6 @@ def strtobool(val): """Convert a string representation of truth to true (1) or false (0). - Copied from python + Copied from the python implementation distutils.utils.strtobool True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if From 9bb44f0d8a28a31ce8de4a5b7d64223735390f34 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Sun, 26 Apr 2020 13:35:01 +0200 Subject: [PATCH 09/13] doctest --- pytorch_lightning/utilities/{memory_utils.py => memory.py} | 0 pytorch_lightning/utilities/{parsing_utils.py => parsing.py} | 5 +++++ 2 files changed, 5 insertions(+) rename pytorch_lightning/utilities/{memory_utils.py => memory.py} (100%) rename pytorch_lightning/utilities/{parsing_utils.py => parsing.py} (89%) diff --git a/pytorch_lightning/utilities/memory_utils.py b/pytorch_lightning/utilities/memory.py similarity index 100% rename from pytorch_lightning/utilities/memory_utils.py rename to pytorch_lightning/utilities/memory.py diff --git a/pytorch_lightning/utilities/parsing_utils.py b/pytorch_lightning/utilities/parsing.py similarity index 89% rename from pytorch_lightning/utilities/parsing_utils.py rename to pytorch_lightning/utilities/parsing.py index 7de79710198c8..26fc410dafdb6 100644 --- a/pytorch_lightning/utilities/parsing_utils.py +++ b/pytorch_lightning/utilities/parsing.py @@ -5,6 +5,11 @@ def strtobool(val): True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if 'val' is anything else. 
+ + >>> strtobool('YES') + 1 + >>> strtobool('FALSE') + 0 """ val = val.lower() if val in ('y', 'yes', 't', 'true', 'on', '1'): From 2301a04427dfa773a4c95ac44782e919fdb0522f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 07:44:23 -0400 Subject: [PATCH 10/13] fixed hparams section --- docs/source/hyperparameters.rst | 183 ++++++++++++++++----------- pytorch_lightning/trainer/trainer.py | 4 +- 2 files changed, 112 insertions(+), 75 deletions(-) diff --git a/docs/source/hyperparameters.rst b/docs/source/hyperparameters.rst index f9f36448223d2..7a20b5f698e99 100644 --- a/docs/source/hyperparameters.rst +++ b/docs/source/hyperparameters.rst @@ -3,36 +3,111 @@ Hyperparameters Lightning has utilities to interact seamlessly with the command line ArgumentParser and plays well with the hyperparameter optimization framework of your choice. -LightiningModule hparams +ArgumentParser +^^^^^^^^^^^^^^ +Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser + +.. code-block:: python + + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument('--layer_1_dim', type=int, default=128) + + args = parser.parse_args() + +This allows you to call your program like so: + +.. code-block:: bash + + python trainer.py --layer_1_dim 64 + + +Argparser Best Practices ^^^^^^^^^^^^^^^^^^^^^^^^ +It is best practice to layer your arguments in three sections. -Normally, we don't hard-code the values to a model. We usually use the command line to -modify the network. The `Trainer` can add all the available options to an ArgumentParser. + 1. Trainer args (gpus, num_nodes, etc...) + 2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...) + 3. Program arguments (data_path, cluster_email, etc...) + +We can do this as follows. First, in your LightningModule, define the arguments +specific to that module. Remember that data splits or data paths may also be specific to +a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10). .. code-block:: python + class LitModel(LightningModule): + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--encoder_layers', type=int, default=12) + parser.add_argument('--data_path', type=str, default='/some/path') + return parser + +Now in your main trainer file, add the Trainer args, the program args, and add the model args + +.. code-block:: python + + # ---------------- + # trainer_main.py + # ---------------- from argparse import ArgumentParser parser = ArgumentParser() - # parametrize the network - parser.add_argument('--layer_1_dim', type=int, default=128) - parser.add_argument('--layer_2_dim', type=int, default=256) - parser.add_argument('--batch_size', type=int, default=64) + # add PROGRAM level args + parser.add_argument('--conda_env', type=str, default='some_name') + parser.add_argument('--notification_email', type=str, default='will@email.com') + + # add model specific args + parser = LitModel.add_model_specific_args(parser) - # add all the available options to the trainer + # add all the available trainer options to argparse + # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() + hparams = parser.parse_args() -Now we can parametrize the LightningModule. +Now you can call run your program like so + +.. 
code-block:: bash + + python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12 + +Finally, make sure to start the training like so: + +.. code-block:: bash + + hparams = parser.parse_args() + + # YES + model = LitModel(hparams) + + # NO + # model = LitModel(learning_rate=hparams.learning_rate, ...) + + # YES + trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...) + + # NO + trainer = Trainer(gpus=hparams.gpus, ...) + + +LightiningModule hparams +^^^^^^^^^^^^^^^^^^^^^^^^ + +Normally, we don't hard-code the values to a model. We usually use the command line to +modify the network and read those values in the LightningModule .. code-block:: python - :emphasize-lines: 5,6,7,12,14 class LitMNIST(pl.LightningModule): def __init__(self, hparams): super().__init__() + + # do this to save all arguments in any logger (tensorboard) self.hparams = hparams self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim) @@ -49,86 +124,44 @@ Now we can parametrize the LightningModule. def configure_optimizers(self): return Adam(self.parameters(), lr=self.hparams.learning_rate) - hparams = parse_args() - model = LitMNIST(hparams) + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) -.. note:: Bonus! if (hparams) is in your module, Lightning will save it into the checkpoint and restore your - model using those hparams exactly. + parser.add_argument('--layer_1_dim', type=int, default=128) + parser.add_argument('--layer_2_dim', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--learning_rate', type=float, default=0.002) + return parser -And we can also add all the flags available in the Trainer to the Argparser. +Now pass in the params when you init your model .. code-block:: python - # add all the available Trainer options to the ArgParser - parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() - -And now you can start your program with + hparams = parse_args() + model = LitMNIST(hparams) -.. code-block:: bash +The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule. +This does two things: - # now you can use any trainer flag - $ python main.py --num_nodes 2 --gpus 8 + 1. It adds them automatically to tensorboard logs under the hparams tab. + 2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly. Trainer args ^^^^^^^^^^^^ - -It also gets annoying to map each argument into the Argparser. Luckily we have -a default parser +To recap, add ALL possible trainer flags to the argparser and init the Trainer this way .. code-block:: python parser = ArgumentParser() - - # add all options available in the trainer such as (max_epochs, etc...) parser = Trainer.add_argparse_args(parser) + hparams = parser.parse_args() -We set up the main training entry point file like this: - -.. 
code-block:: python - - def main(args): - model = LitMNIST(hparams=args) - trainer = Trainer(max_epochs=args.max_epochs) - trainer.fit(model) + trainer = Trainer.from_argparse_args(hparams) - if __name__ == '__main__': - parser = ArgumentParser() + # or if you need to pass in callbacks + trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...]) - # adds all the trainer options as default arguments (like max_epochs) - parser = Trainer.add_argparse_args(parser) - - # parametrize the network - parser.add_argument('--layer_1_dim', type=int, default=128) - parser.add_argument('--layer_1_dim', type=int, default=256) - parser.add_argument('--batch_size', type=int, default=64) - args = parser.parse_args() - - # train - main(args) - -And now we can train like this: - -.. code-block:: bash - - $ python main.py --layer_1_dim 128 --layer_2_dim 256 --batch_size 64 --max_epochs 64 - -But it would also be nice to pass in any arbitrary argument to the trainer. -We can do it by changing how we init the trainer. - -.. code-block:: python - - def main(args): - model = LitMNIST(hparams=args) - - # makes all trainer options available from the command line - trainer = Trainer.from_argparse_args(args) - -and now we can do this: - -.. code-block:: bash - - $ python main.py --gpus 1 --min_epochs 12 --max_epochs 64 --arbitrary_trainer_arg some_value Multiple Lightning Modules ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -173,7 +206,7 @@ Now we can allow each model to inject the arguments it needs in the main.py model = LitMNIST(hparams=args) model = LitMNIST(hparams=args) - trainer = Trainer(max_epochs=args.max_epochs) + trainer = Trainer.from_argparse_args(args) trainer.fit(model) if __name__ == '__main__': @@ -182,6 +215,8 @@ Now we can allow each model to inject the arguments it needs in the main.py # figure out which model to use parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist') + + # THIS LINE IS KEY TO PULL THE MODEL NAME temp_args = parser.parse_known_args() # let the model add what it wants diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 19130f7289858..e8fe9e2e8512b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -638,9 +638,11 @@ def arg_default(x): return parser @classmethod - def from_argparse_args(cls, args): + def from_argparse_args(cls, args, **kwargs): params = vars(args) + params.update(**kwargs) + return cls(**params) @property From 3bf3144d2ad6a9d620a5eb79d1496700e5ce605d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 07:59:46 -0400 Subject: [PATCH 11/13] fixed hparams section --- pytorch_lightning/trainer/logging.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 656c20265d66c..f52b4af864837 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -5,7 +5,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection -from pytorch_lightning.utilities import memory_utils +from pytorch_lightning.utilities import memory class TrainerLoggingMixin(ABC): @@ -174,7 +174,7 @@ def process_output(self, output, train=False): # detach all metrics for callbacks to prevent memory leaks # no .item() because it will slow things down - 
callback_metrics = memory_utils.recursive_detach(callback_metrics) + callback_metrics = memory.recursive_detach(callback_metrics) return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e8fe9e2e8512b..eeafb73f70322 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -32,7 +32,7 @@ from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities import parsing_utils +from pytorch_lightning.utilities import parsing try: @@ -611,7 +611,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: for allowed_type in (at for at in allowed_types if at in arg_types): if allowed_type is bool: def allowed_type(x): - return bool(parsing_utils.strtobool(x)) + return bool(parsing.strtobool(x)) if arg == 'gpus': def allowed_type(x): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 33c8339bb85d0..be5d70f82e4a0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -158,7 +158,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities import memory_utils +from pytorch_lightning.utilities import memory try: from apex import amp From 68f5c3fc7c5a108a9cacaf01f4144ac671653b33 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 26 Apr 2020 08:01:20 -0400 Subject: [PATCH 12/13] fixed hparams section --- pytorch_lightning/trainer/logging.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 2 +- pytorch_lightning/utilities/{memory.py => memory_utils.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename pytorch_lightning/utilities/{memory.py => memory_utils.py} (100%) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index f52b4af864837..656c20265d66c 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -5,7 +5,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection -from pytorch_lightning.utilities import memory +from pytorch_lightning.utilities import memory_utils class TrainerLoggingMixin(ABC): @@ -174,7 +174,7 @@ def process_output(self, output, train=False): # detach all metrics for callbacks to prevent memory leaks # no .item() because it will slow things down - callback_metrics = memory.recursive_detach(callback_metrics) + callback_metrics = memory_utils.recursive_detach(callback_metrics) return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index be5d70f82e4a0..33c8339bb85d0 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -158,7 +158,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import 
rank_zero_warn -from pytorch_lightning.utilities import memory +from pytorch_lightning.utilities import memory_utils try: from apex import amp diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory_utils.py similarity index 100% rename from pytorch_lightning/utilities/memory.py rename to pytorch_lightning/utilities/memory_utils.py From ab73fa380f2abba69542735164e6abc257f773f1 Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Sun, 26 Apr 2020 14:07:41 +0200 Subject: [PATCH 13/13] formatting --- pytorch_lightning/trainer/logging.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 3 --- pytorch_lightning/utilities/{memory_utils.py => memory.py} | 0 3 files changed, 2 insertions(+), 5 deletions(-) rename pytorch_lightning/utilities/{memory_utils.py => memory.py} (100%) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 656c20265d66c..c1d598dc71875 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -5,7 +5,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection -from pytorch_lightning.utilities import memory_utils +from pytorch_lightning.utilities.memory import recursive_detach class TrainerLoggingMixin(ABC): @@ -174,7 +174,7 @@ def process_output(self, output, train=False): # detach all metrics for callbacks to prevent memory leaks # no .item() because it will slow things down - callback_metrics = memory_utils.recursive_detach(callback_metrics) + callback_metrics = recursive_detach(callback_metrics) return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 33c8339bb85d0..8697270cf3de6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -141,7 +141,6 @@ def training_step(self, batch, batch_idx): """ -import copy from abc import ABC, abstractmethod from typing import Callable from typing import Union, List @@ -154,11 +153,9 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities import memory_utils try: from apex import amp diff --git a/pytorch_lightning/utilities/memory_utils.py b/pytorch_lightning/utilities/memory.py similarity index 100% rename from pytorch_lightning/utilities/memory_utils.py rename to pytorch_lightning/utilities/memory.py
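
A short usage sketch, added for illustration and not part of any patch in this series: it assumes the final state of the tree after PATCH 13, in which ``Trainer.add_argparse_args`` converts boolean flags through ``pytorch_lightning.utilities.parsing.strtobool`` and ``Trainer.from_argparse_args`` forwards extra keyword arguments to the ``Trainer`` constructor. The ``max_epochs=1`` override and the example flag values are illustrative assumptions, not values taken from the patches.

.. code-block:: python

    # Editorial sketch: assumes the tree after PATCH 13 is applied.
    from argparse import ArgumentParser

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.parsing import strtobool

    # the vendored helper behaves like distutils.util.strtobool
    assert strtobool('YES') == 1   # matches the doctest added in PATCH 09
    assert strtobool('off') == 0


    def main():
        parser = ArgumentParser()

        # every Trainer.__init__ argument becomes a CLI flag; bool flags such
        # as --fast_dev_run are now converted via strtobool instead of distutils
        parser = pl.Trainer.add_argparse_args(parser)
        args = parser.parse_args()

        # extra keyword arguments are merged on top of the parsed namespace,
        # as enabled by the **kwargs added to from_argparse_args in PATCH 10
        # (max_epochs=1 here is only an illustration of such an override)
        trainer = pl.Trainer.from_argparse_args(args, max_epochs=1)
        return trainer


    if __name__ == '__main__':
        # e.g.: python sketch.py --fast_dev_run true --max_epochs 2
        main()

With this parser, a boolean spelling outside the accepted set ('y', 'yes', 't', 'true', 'on', '1' and their false counterparts) makes ``strtobool`` raise ``ValueError``, which argparse reports as an ordinary command-line error rather than failing inside the old ``distutils``-based conversion.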