From cfbd364a021471749e0e7c3014d299edf17f7d41 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 1 Sep 2022 22:47:36 +0530 Subject: [PATCH] Refactor Loggers : Move code outside train.py (#9241) * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- train.py | 11 +++++------ utils/loggers/__init__.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 0cd4a7f065a6..29293aa612cf 100644 --- a/train.py +++ b/train.py @@ -91,17 +91,16 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio data_dict = None if RANK in {-1, 0}: loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance - if loggers.clearml: - data_dict = loggers.clearml.data_dict # None if no ClearML dataset or filled in by ClearML - if loggers.wandb: - data_dict = loggers.wandb.data_dict - if resume: - weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size # Register actions for k in methods(loggers): callbacks.register_action(k, callback=getattr(loggers, k)) + # Process custom dataset artifact link + data_dict = loggers.remote_dataset + if resume: # If resuming runs from remote artifact + weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size + # Config plots = not evolve and not opt.noplots # create plots cuda = device.type != 'cpu' diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 880039b1914c..1aa8427f9127 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -107,6 +107,17 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, else: self.clearml = None + @property + def remote_dataset(self): + # Get data_dict if custom dataset artifact link is provided + data_dict = None + if self.clearml: + data_dict = self.clearml.data_dict + if self.wandb: + data_dict = self.wandb.data_dict + + return data_dict + def on_train_start(self): # Callback runs on train start pass