diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 595a5e84bf630..89c72883fc497 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pprint import pprint -from typing import Dict, Iterable, Union +from typing import Dict, Iterable, Optional, Union import torch @@ -32,7 +32,7 @@ class LoggerConnector: - def __init__(self, trainer, log_gpu_memory: bool): + def __init__(self, trainer, log_gpu_memory: Optional[str] = None): self.trainer = trainer self.log_gpu_memory = log_gpu_memory self._callback_metrics = MetricsHolder()