diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7c8a89bffb87a..c97dbcfa2808f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -946,7 +946,9 @@ def fit( xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method) # load weights if not interrupted - self.load_spawn_weights(model) + if self.on_colab_kaggle: + self.load_spawn_weights(model) + self.model = model # ON CPU