From 6709088ade19e13901e464260622af5bf5399d23 Mon Sep 17 00:00:00 2001 From: Siavash Sakhavi Date: Wed, 17 Jun 2020 11:22:25 +0800 Subject: [PATCH 1/3] Checking if the parameters are a DictConfig Object This is in reference to #2058 . To be honest, I have no idea how I should go about writing a test for this. --- pytorch_lightning/loggers/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 823b9830272fa..16eb65a55c099 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -9,6 +9,7 @@ import torch from pytorch_lightning.utilities import rank_zero_only +from omegaconf import DictConfig class LightningLoggerBase(ABC): @@ -174,9 +175,9 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, dict): + if isinstance(input_dict, (dict, DictConfig)): for key, value in input_dict.items(): - if isinstance(value, (dict, Namespace)): + if isinstance(value, (dict, DictConfig, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): yield d From fbdb41d8dc1e7e660400fa04124d136ab1856871 Mon Sep 17 00:00:00 2001 From: Siavash Sakhavi Date: Wed, 17 Jun 2020 16:25:04 +0800 Subject: [PATCH 2/3] Update pytorch_lightning/loggers/base.py Co-authored-by: Jirka Borovec --- pytorch_lightning/loggers/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 16eb65a55c099..626f074212080 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -9,7 +9,10 @@ import torch from pytorch_lightning.utilities import rank_zero_only -from omegaconf import DictConfig +try: + from omegaconf import Container +except ImportError: + Container = None class LightningLoggerBase(ABC): From 33deec2c9fb52ac3008251d35700b41b9cccbea8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 19 Jun 2020 20:44:47 +0200 Subject: [PATCH 3/3] fix ... --- pytorch_lightning/loggers/base.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 626f074212080..8082cbac3eae2 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -3,17 +3,11 @@ import operator from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple +from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping import numpy as np import torch -from pytorch_lightning.utilities import rank_zero_only -try: - from omegaconf import Container -except ImportError: - Container = None - class LightningLoggerBase(ABC): """ @@ -178,9 +172,9 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] - if isinstance(input_dict, (dict, DictConfig)): + if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): - if isinstance(value, (dict, DictConfig, Namespace)): + if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): yield d