From 286106dc8577418114fb3af49ea5b5aaa4344eea Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia
Date: Sun, 28 Mar 2021 19:41:36 +0530
Subject: [PATCH] W&B resume DDP from run link fix (#2579)

* W&B resume DDP from run link fix

* Native DDP W&B support for training and resuming
---
 train.py                           |  4 +-
 utils/wandb_logging/wandb_utils.py | 66 +++++++++++++++++++++++-------
 2 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/train.py b/train.py
index 211cc04fb63b..d5b2d1b75c52 100644
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
-from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id
+from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 
 logger = logging.getLogger(__name__)
 
@@ -496,7 +496,7 @@ def train(hyp, opt, device, tb_writer=None):
         check_requirements()
 
     # Resume
-    wandb_run = resume_and_get_id(opt)
+    wandb_run = check_wandb_resume(opt)
     if opt.resume and not wandb_run:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'

diff --git a/utils/wandb_logging/wandb_utils.py b/utils/wandb_logging/wandb_utils.py
index d6dd256366e0..17132874e0d0 100644
--- a/utils/wandb_logging/wandb_utils.py
+++ b/utils/wandb_logging/wandb_utils.py
@@ -23,7 +23,7 @@
 WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
 
 
-def remove_prefix(from_string, prefix):
+def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
     return from_string[len(prefix):]
 
 
@@ -33,35 +33,73 @@ def check_wandb_config_file(data_config_file):
         return wandb_config
     return data_config_file
 
+def get_run_info(run_path):
+    run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
+    run_id = run_path.stem
+    project = run_path.parent.stem
+    model_artifact_name = 'run_' + run_id + '_model'
+    return run_id, project, model_artifact_name
 
-def resume_and_get_id(opt):
-    # It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
+def check_wandb_resume(opt):
+    if opt.global_rank not in [-1, 0]: process_wandb_config_ddp_mode(opt)  # non-main DDP ranks localize artifact-backed data paths
     if isinstance(opt.resume, str):
         if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
-            run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
-            run_id = run_path.stem
-            project = run_path.parent.stem
-            model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
-            assert wandb, 'install wandb to resume wandb runs'
-            # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
-            run = wandb.init(id=run_id, project=project, resume='allow')
-            opt.resume = model_artifact_name
-            return run
+            if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
+                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                api = wandb.Api()
+                artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
+                modeldir = artifact.download()
+                opt.weights = str(Path(modeldir) / 'last.pt')
+            return True
     return None
 
 
+def process_wandb_config_ddp_mode(opt):
+    with open(opt.data) as f:
+        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
+    train_dir, val_dir = None, None
+    if data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
+        api = wandb.Api()
+        train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
+        train_dir = train_artifact.download()
+        train_path = Path(train_dir) / 'data/images/'
+        data_dict['train'] = str(train_path)
+
+    if data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
+        api = wandb.Api()
+        val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
+        val_dir = val_artifact.download()
+        val_path = Path(val_dir) / 'data/images/'
+        data_dict['val'] = str(val_path)
+    if train_dir or val_dir:
+        ddp_data_path = str(Path(val_dir or train_dir) / 'wandb_local_data.yaml')  # val_dir may be None if only train is artifact-backed
+        with open(ddp_data_path, 'w') as f:
+            yaml.dump(data_dict, f)
+        opt.data = ddp_data_path
+
+
 class WandbLogger():
     def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
         # Pre-training routine --
         self.job_type = job_type
         self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
-        if self.wandb:
+        # It's more elegant to stick to one wandb.init call, but useful config data would be overwritten by WandbLogger's wandb.init call
+        if isinstance(opt.resume, str):  # checks resume from artifact
+            if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
+                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
+                assert wandb, 'install wandb to resume wandb runs'
+                # Resume wandb-artifact:// runs here | workaround so wandb.config is not overwritten
+                self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
+                opt.resume = model_artifact_name
+        elif self.wandb:
             self.wandb_run = wandb.init(config=opt,
                                         resume="allow",
                                         project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                         name=name,
                                         job_type=job_type,
-                                        id=run_id) if not wandb.run else wandb.run
+                                        id=run_id) if not wandb.run else wandb.run
+        if self.wandb_run:
             if self.job_type == 'Training':
                 if not opt.resume:
                     wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
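
For reference, a minimal runnable sketch of how the new get_run_info helper decomposes a
wandb-artifact:// run link. The helper bodies are copied from the diff above; the entity,
project and run id in the example link are hypothetical placeholders, not real W&B objects.

    from pathlib import Path

    WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'

    def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
        # strip the scheme so the remainder can be treated as an entity/project/run_id path
        return from_string[len(prefix):]

    def get_run_info(run_path):
        # 'wandb-artifact://entity/project/run_id' -> (run_id, project, 'run_<run_id>_model')
        run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
        run_id = run_path.stem
        project = run_path.parent.stem
        model_artifact_name = 'run_' + run_id + '_model'
        return run_id, project, model_artifact_name

    # hypothetical run link; '3abc12de' stands in for a real W&B run id
    print(get_run_info('wandb-artifact://my-entity/YOLOv5/3abc12de'))
    # -> ('3abc12de', 'YOLOv5', 'run_3abc12de_model')

With --resume set to such a link, the main rank re-attaches to the W&B run via
wandb.init(id=run_id, project=project, resume='allow') inside WandbLogger, while non-main
DDP ranks only download the run_<run_id>_model:latest artifact and point opt.weights at
its bundled last.pt.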