diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 823b9830272fa..8082cbac3eae2 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -3,13 +3,11 @@ import operator from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple +from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping import numpy as np import torch -from pytorch_lightning.utilities import rank_zero_only - class LightningLoggerBase(ABC): """ @@ -174,9 +172,9 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, dict): + if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): - if isinstance(value, (dict, Namespace)): + if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): yield d