diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index d2769c3e8e25c..6ba0ff8678b21 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -145,6 +145,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -204,6 +205,7 @@ class TrainerDDPMixin(ABC):
     node_rank: int
     tpu_cores: int
     testing: bool
+    datamodule: Optional[LightningDataModule]

     @property
     @abstractmethod
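
For context, a minimal sketch (not part of this diff) of how the new `datamodule` attribute on `TrainerDDPMixin` gets populated: a `LightningDataModule` passed to `Trainer.fit` is stored on the trainer, so the DDP worker processes this mixin spawns can reach the same datamodule. `RandomDataset`, `RandomDataModule`, and all sizes below are illustrative placeholders, assuming the 0.9-era API in which `fit` accepts a `datamodule` keyword.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer
from pytorch_lightning.core.datamodule import LightningDataModule


class RandomDataset(Dataset):
    """Illustrative dataset of random feature vectors (placeholder)."""

    def __init__(self, size: int = 32, length: int = 256):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class RandomDataModule(LightningDataModule):
    """Illustrative datamodule; `setup` runs in each DDP process."""

    def setup(self, stage=None):
        self.train_set = RandomDataset()

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)


# Assumed usage: passing the datamodule to `fit` is what sets the
# `trainer.datamodule` attribute typed in the hunk above.
# dm = RandomDataModule()
# trainer = Trainer(gpus=2, distributed_backend='ddp')
# trainer.fit(model, datamodule=dm)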