Skip to content

Commit

Permalink
DagsHub Logger: Fix unsupported metric formats for MLflow, Add exampl…
Browse files Browse the repository at this point in the history
…e notebook (#915)

* DagsHub Logger: Fix the error when encounter the unsupported type for mlflow. Add a Colab Notebook using DagsHub Logger

* Update dagshub_sg_logger.py

* Update dagshub_sg_logger.py

* specify unsupported symbols

* Sanitize illegal chars for MLflow, and make sure timestep is correct

* Fixed black

---------

Co-authored-by: Dean <dean@dagshub.com>
Co-authored-by: Ran Rubin <ranrubin@gmail.com>
Co-authored-by: Dean P <deanp07@gmail.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
5 people committed Aug 8, 2023
1 parent 2d3004a commit f24954c
Show file tree
Hide file tree
Showing 2 changed files with 3,787 additions and 13 deletions.
63 changes: 50 additions & 13 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from pathlib import Path
from typing import Optional
from typing import Optional, Mapping

import torch

Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
def splitter(repo):
splitted = repo.split("/")
if len(splitted) != 2:
raise ValueError(f"Invalid input, should be owner_name/repo_name, but got {repo} instead")
raise Exception(f"Invalid input, should be owner_name/repo_name, but got {repo} instead")
return splitted[1], splitted[0]

def _init_env_dependency(self):
Expand Down Expand Up @@ -166,28 +167,64 @@ def _dvc_add(self, local_path="", remote_path=""):
def _dvc_commit(self, commit=""):
self.dvc_folder.commit(commit, versioning="dvc", force=True)

@multi_process_safe
def _get_nested_dict_values(self, d, parent_key="", sep="/"):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, Mapping):
items.extend(self._get_nested_dict_values(v, new_key, sep=sep))
else:
items.append((new_key, v))
return items

@multi_process_safe
def _sanitize_special_characters(self, text):
pattern = r"[!\"#$%&'()*+,:;<=>?@[\]^`{|}~\t\n\r\x0b\x0c]"
valid_text = re.sub(pattern, "_", text)
return valid_text

@multi_process_safe
def add_config(self, tag: str, config: dict):
super(DagsHubSGLogger, self).add_config(tag=tag, config=config)
param_keys = config.keys()
for pk in param_keys:
for k, v in config[pk].items():
try:
mlflow.log_params({k: v})
except Exception:
logger.warning(f"Skip to log {k}: {v}")
flatten_dict = self._get_nested_dict_values(d=config)
for k, v in flatten_dict:
try:
k_sanitized = self._sanitize_special_characters(k)
mlflow.log_params({k_sanitized: v})
except Exception as e:
err_msg = f"Fail to log the config: {k}, got an expection: {e}"
logger.warning(err_msg)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: [int, TimeUnit] = 0):
super(DagsHubSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()
mlflow.log_metric(key=tag, value=scalar_value, step=global_step)
try:
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()

tag_sanitized = self._sanitize_special_characters(tag)
mlflow.log_metric(key=tag_sanitized, value=scalar_value, step=global_step)
except Exception as e:
err_msg = f"Fail to log the metric: {tag}, got an expection: {e}"
raise Exception(err_msg)

@multi_process_safe
def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
super(DagsHubSGLogger, self).add_scalars(tag_scalar_dict=tag_scalar_dict, global_step=global_step)
mlflow.log_metrics(metrics=tag_scalar_dict, step=global_step)
try:
mlflow.log_metrics(metrics=tag_scalar_dict, step=global_step)
except Exception:
flatten_dicts = self._get_nested_dict_values(tag_scalar_dict)
for k, v in flatten_dicts:
try:
if isinstance(v, torch.Tensor):
v = v.item()
else:
v = float(v)
self.add_scalar(tag=k.replace("@", "at"), scalar_value=v, global_step=global_step)
except Exception as e:
logger.warning(e)

@multi_process_safe
def close(self):
Expand Down
Loading

0 comments on commit f24954c

Please sign in to comment.