Skip to content

Commit

Permalink
specify unsupported symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
deanp70 committed Jun 26, 2023
1 parent bc3869f commit 3b8818b
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from pathlib import Path
from typing import Optional, Mapping

Expand Down Expand Up @@ -176,6 +177,14 @@ def _get_nested_dict_values(self, d, parent_key="", sep="/"):
items.append((new_key, v))
return items

@multi_process_safe
def _contains_special_characters(self, text):
pattern = r"[!\"#$%&'()*+,:;<=>?@[\]^`{|}~\t\n\r\x0b\x0c]"
matches = re.findall(pattern, text)
if matches:
return True, ", ".join(matches)
return False, None

@multi_process_safe
def add_config(self, tag: str, config: dict):
super(DagsHubSGLogger, self).add_config(tag=tag, config=config)
Expand All @@ -184,15 +193,25 @@ def add_config(self, tag: str, config: dict):
try:
mlflow.log_params({k: v})
except Exception as e:
logger.debug(e)
is_contain, spec_char = self._contains_special_characters(k)
if is_contain:
err_msg = f"Fail to log {k}, please remove the unsupported characters: {spec_char}"
else:
err_msg = f"Fail to log the config: {k}, got an expection: {e}"
logger.warning(err_msg)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0):
super(DagsHubSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
try:
mlflow.log_metric(key=tag, value=scalar_value, step=global_step)
except Exception as e:
logger.debug(e)
is_contain, spec_char = self._contains_special_characters(tag)
if is_contain:
err_msg = f"Fail to log {tag}, please remove the unsupported characters: {spec_char}"
else:
err_msg = f"Fail to log the metric: {tag}, got an expection: {e}"
raise Exception(err_msg)

@multi_process_safe
def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
Expand All @@ -209,7 +228,7 @@ def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
v = float(v)
self.add_scalar(tag=k.replace("@", "at"), scalar_value=v, global_step=global_step)
except Exception as e:
logger.debug(f"error: {e}")
logger.warning(e)

@multi_process_safe
def close(self):
Expand Down

0 comments on commit 3b8818b

Please sign in to comment.