W&B DDP fix #2574

Merged
merged 1 commit into from Mar 23, 2021
8 changes: 5 additions & 3 deletions train.py
@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None):
    is_coco = opt.data.endswith('coco.yaml')

    # Logging- Doing this before checking the dataset. Might update data_dict
    loggers = {'wandb': None} # loggers dict
    if rank in [-1, 0]:
        opt.hyp = hyp # add hyperparameters
        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
        loggers['wandb'] = wandb_logger.wandb
        data_dict = wandb_logger.data_dict
        if wandb_logger.wandb:
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
    loggers = {'wandb': wandb_logger.wandb} # loggers dict

    nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None):
            fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
            wandb_logger.end_epoch(best_result=best_fitness == fi)

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None):
                        wandb_logger.log_model(
                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                del ckpt
                wandb_logger.end_epoch(best_result=best_fitness == fi)

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training
@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None):
            wandb_logger.wandb.log_artifact(str(final), type='model',
                                            name='run_' + wandb_logger.wandb_run.id + '_model',
                                            aliases=['last', 'best', 'stripped'])
        wandb_logger.finish_run()
    else:
        dist.destroy_process_group()
    torch.cuda.empty_cache()
    wandb_logger.finish_run()
    return results


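For reference, a minimal standalone sketch of the pattern the train.py hunks above apply: every process starts from a well-defined loggers dict, and only rank -1/0 creates and uses the W&B run, so DDP worker ranks never reference an undefined wandb_logger. This is not YOLOv5 code; the train(rank) wrapper, project name and RANK environment read are illustrative assumptions.

# Sketch only: rank-guarded W&B logging, mirroring the train.py change above.
import os

try:
    import wandb  # optional dependency
except ImportError:
    wandb = None

def train(rank):
    loggers = {'wandb': None}  # every rank gets a well-defined dict
    wandb_logger = None
    if rank in [-1, 0] and wandb:  # only the master process creates the run
        wandb_logger = wandb.init(project='demo', job_type='Training')
        loggers['wandb'] = wandb_logger

    # ... training loop would go here; log only where a run exists ...
    if loggers['wandb']:
        loggers['wandb'].log({'epoch': 0})

    if wandb_logger:  # finish only on the rank that owns the run
        wandb_logger.finish()

if __name__ == '__main__':
    train(rank=int(os.getenv('RANK', -1)))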
5 changes: 4 additions & 1 deletion utils/wandb_logging/wandb_utils.py
@@ -16,9 +16,9 @@

try:
    import wandb
    from wandb import init, finish
except ImportError:
    wandb = None
    print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'

@@ -71,6 +71,9 @@ def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
                self.data_dict = self.setup_training(opt, data_dict)
            if self.job_type == 'Dataset Creation':
                self.data_dict = self.check_and_upload_dataset(opt)
        else:
            print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")


    def check_and_upload_dataset(self, opt):
        assert wandb, 'Install wandb to upload dataset'
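A minimal sketch of the wandb_utils.py change above: the install hint is no longer printed at import time (where every DDP worker importing the module would emit it), but only when a logger is constructed without an active W&B run. This is not the YOLOv5 implementation; the DemoLogger class and message text are illustrative.

# Sketch only: defer the "install wandb" hint from import time to logger construction.
try:
    import wandb
except ImportError:
    wandb = None  # stay silent at import time; workers importing this module print nothing

class DemoLogger:
    """Toy stand-in for WandbLogger, used only to show where the hint is printed."""
    def __init__(self, project='demo'):
        self.wandb_run = wandb.init(project=project) if wandb else None
        if not self.wandb_run:
            print("wandb: Install Weights & Biases for logging with 'pip install wandb' (recommended)")

# Only the master rank would construct the logger, so the hint prints at most once:
# if rank in [-1, 0]:
#     logger = DemoLogger()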