diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e0a7b43a872aa..c8cb81ed090b1 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -291,7 +291,7 @@ def transfer_batch_to_tpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def clip_gradients(self):
+    def clip_gradients(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
@@ -817,7 +817,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
         if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             self.scaler.unscale_(optimizer)
-        self.clip_gradients()
+        self.clip_gradients(optimizer)

         # ------------------
         # .STEP + ZERO_GRAD
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 20eeff3878cc2..70070296b1a37 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -27,9 +27,17 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda

+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5

@@ -60,14 +68,17 @@ def restore(self, *args):
     def fit(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    def clip_gradients(self):
+    def clip_gradients(self, optimizer):
         # this code is a modification of torch.nn.utils.clip_grad_norm_
         # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val <= 0:
             return

         model = self.get_model()
-        parameters = model.parameters()
+        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
+            parameters = amp.master_params(optimizer)
+        else:
+            parameters = model.parameters()
         max_norm = float(self.gradient_clip_val)
         norm_type = float(2.0)
         if isinstance(parameters, torch.Tensor):
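
The sketch below illustrates, outside of Lightning, the pattern this diff applies: when apex AMP keeps FP32 "master" copies of the weights, the gradients live on `amp.master_params(optimizer)`, so clipping `model.parameters()` would operate on the wrong (FP16) tensors. This is only a minimal illustration of the idea; the `clip_after_backward` helper and its arguments are hypothetical and not part of the PR.

```python
# Minimal sketch of the apex-aware clipping pattern (names are illustrative).
import torch
from torch.nn.utils import clip_grad_norm_

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


def clip_after_backward(model, optimizer, max_norm, using_apex):
    # Pick the parameter set that actually owns the gradients:
    # apex O1/O2 accumulates grads on the master (FP32) params held by the optimizer.
    if using_apex and APEX_AVAILABLE:
        parameters = amp.master_params(optimizer)
    else:
        parameters = model.parameters()
    # Standard global-norm clipping over the selected parameters.
    clip_grad_norm_(parameters, max_norm)
```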