Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves docs and handling of entities and resuming by WandbLogger #3264

Merged
merged 4 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def train(hyp, opt, device, tb_writer=None):
if wandb_logger.wandb and not opt.evolve: # Log the stripped model
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['last', 'best', 'stripped'])
aliases=['latest', 'last', 'best', 'stripped'])
wandb_logger.finish_run()
else:
dist.destroy_process_group()
Expand Down
29 changes: 22 additions & 7 deletions utils/wandb_logging/wandb_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Utilities and tools for tracking runs with Weights & Biases."""
import json
import sys
from pathlib import Path
Expand Down Expand Up @@ -35,18 +36,19 @@ def get_run_info(run_path):
run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
run_id = run_path.stem
project = run_path.parent.stem
entity = run_path.parent.parent.stem
model_artifact_name = 'run_' + run_id + '_model'
return run_id, project, model_artifact_name
return entity, project, run_id, model_artifact_name


def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs
run_id, project, model_artifact_name = get_run_info(opt.resume)
entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api()
artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
modeldir = artifact.download()
opt.weights = str(Path(modeldir) / "last.pt")
return True
Expand Down Expand Up @@ -78,23 +80,36 @@ def process_wandb_config_ddp_mode(opt):


class WandbLogger():
"""Log training runs, datasets, models, and predictions to Weights & Biases.

This logger sends information to W&B at wandb.ai. By default, this information
includes hyperparameters, system configuration and metrics, model metrics,
and basic data metrics and analyses.

By providing additional command line arguments to train.py, datasets,
models and predictions can also be logged.

For more on how this logger is used, see the Weights & Biases documentation:
https://docs.wandb.com/guides/integrations/yolov5
"""
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
# Pre-training routine --
self.job_type = job_type
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
run_id, project, model_artifact_name = get_run_info(opt.resume)
entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
assert wandb, 'install wandb to resume wandb runs'
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
self.wandb_run = wandb.init(id=run_id, project=project, entity=entity, resume='allow')
opt.resume = model_artifact_name
elif self.wandb:
self.wandb_run = wandb.init(config=opt,
resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
entity=opt.entity,
name=name,
job_type=job_type,
id=run_id) if not wandb.run else wandb.run
Expand Down Expand Up @@ -188,7 +203,7 @@ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
})
model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
wandb.log_artifact(model_artifact,
aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
print("Saving model artifact on epoch ", epoch + 1)

def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
Expand Down Expand Up @@ -291,7 +306,7 @@ def end_epoch(self, best_result=False):
if self.result_artifact:
train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
self.result_artifact.add(train_results, 'result')
wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
('best' if best_result else '')])
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
Expand Down