Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved W&B integration #2125

Merged
merged 84 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
ba39bfd
Init Commit
AyushExel Feb 2, 2021
5fcd1dc
new wandb integration
AyushExel Feb 3, 2021
c540e3b
Update
AyushExel Feb 3, 2021
8253b24
Use data_dict in test
AyushExel Feb 3, 2021
7f89535
Updates
AyushExel Feb 3, 2021
c149930
Update: scope of log_img
AyushExel Feb 3, 2021
49edb90
Update: scope of log_img
AyushExel Feb 3, 2021
7922683
Update
AyushExel Feb 3, 2021
e1e7179
Update: Fix logging conditions
AyushExel Feb 3, 2021
e632514
Add tqdm bar, support for .txt dataset format
AyushExel Feb 9, 2021
3e8f4ae
Improve Result table Logger
AyushExel Feb 21, 2021
cd094f3
Init Commit
AyushExel Feb 2, 2021
aa5231e
new wandb integration
AyushExel Feb 3, 2021
0fdf3d3
Update
AyushExel Feb 3, 2021
37a2ed6
Use data_dict in test
AyushExel Feb 3, 2021
ac7d4b1
Updates
AyushExel Feb 3, 2021
ebc1d18
Update: scope of log_img
AyushExel Feb 3, 2021
745a272
Update: scope of log_img
AyushExel Feb 3, 2021
7679454
Update
AyushExel Feb 3, 2021
b8210a7
Update: Fix logging conditions
AyushExel Feb 3, 2021
ac9a613
Add tqdm bar, support for .txt dataset format
AyushExel Feb 9, 2021
4f7c150
Improve Result table Logger
AyushExel Feb 21, 2021
b8bbfce
Merge branch 'wandb_clean' of https://github.com/AyushExel/yolov5 int…
AyushExel Feb 23, 2021
c1e6697
Add dataset creation in training script
AyushExel Feb 23, 2021
1948562
Change scope: self.wandb_run
AyushExel Feb 23, 2021
8848f3c
Add wandb-artifact:// natively
AyushExel Feb 25, 2021
deca116
Add suuport for logging dataset while training
AyushExel Feb 26, 2021
20185f2
Cleanup
AyushExel Feb 26, 2021
5287a79
Merge branch 'master' into wandb_clean
AyushExel Feb 26, 2021
e13994d
Fix: Merge conflict
AyushExel Feb 26, 2021
1080952
Fix: CI tests
AyushExel Feb 26, 2021
5a859d4
Automatically use wandb config
AyushExel Feb 27, 2021
519cb7d
Fix: Resume
AyushExel Feb 28, 2021
3242f52
Fix: CI
AyushExel Feb 28, 2021
8128216
Enhance: Using val_table
AyushExel Feb 28, 2021
043befa
More resume enhancement
AyushExel Feb 28, 2021
c2d98f0
FIX : CI
AyushExel Feb 28, 2021
dbb69f4
Add alias
AyushExel Feb 28, 2021
8505a58
Get useful opt config data
AyushExel Mar 1, 2021
04f8880
train.py cleanup
AyushExel Mar 2, 2021
27a33dd
Merge remote-tracking branch 'upstream/master' into wandb_clean
AyushExel Mar 2, 2021
54dee24
Cleanup train.py
AyushExel Mar 2, 2021
21a15a5
more cleanup
AyushExel Mar 2, 2021
d38c620
Cleanup| CI fix
AyushExel Mar 2, 2021
e5400ba
Reformat using PEP8
AyushExel Mar 3, 2021
45e2c55
FIX:CI
AyushExel Mar 3, 2021
75f31d0
Merge remote-tracking branch 'upstream/master' into wandb_clean
AyushExel Mar 6, 2021
613b102
rebase
AyushExel Mar 6, 2021
9772645
remove uneccesary changes
AyushExel Mar 6, 2021
cd1237e
remove uneccesary changes
AyushExel Mar 6, 2021
d172ba1
remove uneccesary changes
AyushExel Mar 6, 2021
7af0186
remove unecessary chage from test.py
AyushExel Mar 6, 2021
51dca6d
FIX: resume from local checkpoint
AyushExel Mar 8, 2021
1438483
FIX:resume
AyushExel Mar 8, 2021
e7d18c6
FIX:resume
AyushExel Mar 8, 2021
22d97a7
Reformat
AyushExel Mar 8, 2021
8e97cdf
Performance improvement
AyushExel Mar 9, 2021
2ffb643
Fix local resume
AyushExel Mar 9, 2021
7836d17
Fix local resume
AyushExel Mar 9, 2021
aa785ec
FIX:CI
AyushExel Mar 9, 2021
f97446e
Fix: CI
AyushExel Mar 9, 2021
807a0e1
Imporve image logging
AyushExel Mar 9, 2021
20b4450
(:(:Redo CI tests:):)
AyushExel Mar 9, 2021
db81c64
Remember epochs when resuming
AyushExel Mar 9, 2021
25ff6b8
Remember epochs when resuming
AyushExel Mar 9, 2021
819ebec
Update DDP location
glenn-jocher Mar 10, 2021
b23a902
merge master
glenn-jocher Mar 14, 2021
f742857
PEP8 reformat
glenn-jocher Mar 14, 2021
350b8ab
0.25 confidence threshold
glenn-jocher Mar 14, 2021
395379e
reset train.py plots syntax to previous
glenn-jocher Mar 14, 2021
a06b25c
reset epochs completed syntax to previous
glenn-jocher Mar 14, 2021
cc49f6a
reset space to previous
glenn-jocher Mar 14, 2021
2d56697
remove brackets
glenn-jocher Mar 14, 2021
ba859a6
reset comment to previous
glenn-jocher Mar 14, 2021
52e3e71
Update: is_coco check, remove unused code
AyushExel Mar 14, 2021
ad1ad8f
Remove redundant print statement
AyushExel Mar 14, 2021
72dd23b
Remove wandb imports
AyushExel Mar 14, 2021
ac955ab
remove dsviz logger from test.py
AyushExel Mar 14, 2021
8bded54
Remove redundant change from test.py
AyushExel Mar 14, 2021
1aca390
remove redundant changes from train.py
AyushExel Mar 14, 2021
4c1c9bf
reformat and improvements
AyushExel Mar 20, 2021
f4923b4
Fix typo
AyushExel Mar 21, 2021
af23506
Merge branch 'master' of https://github.com/ultralytics/yolov5 into w…
AyushExel Mar 21, 2021
ca06d31
Add tqdm tqdm progress when scanning files, naming improvements
AyushExel Mar 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir='
def print(self):
self.display(pprint=True) # print results
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
tuple(self.t))
tuple(self.t))

def show(self):
self.display(show=True) # show results
Expand Down
49 changes: 26 additions & 23 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def test(data,
save_hybrid=False, # for hybrid auto-labelling
save_conf=False, # save auto-label confidences
plots=True,
log_imgs=0, # number of logged images
compute_loss=None):
wandb_logger=None,
compute_loss=None,
is_coco=False):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand Down Expand Up @@ -66,21 +67,19 @@ def test(data,

# Configure
model.eval()
is_coco = data.endswith('coco.yaml') # is COCO dataset
with open(data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) # model dict
if isinstance(data, str):
is_coco = data.endswith('coco.yaml')
with open(data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader)
check_dataset(data) # check
nc = 1 if single_cls else int(data['nc']) # number of classes
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
niou = iouv.numel()

# Logging
log_imgs, wandb = min(log_imgs, 100), None # ceil
try:
import wandb # Weights & Biases
except ImportError:
log_imgs = 0

log_imgs = 0
if wandb_logger and wandb_logger.wandb:
log_imgs = min(wandb_logger.log_imgs, 100)
# Dataloader
if not training:
if device.type != 'cpu':
Expand Down Expand Up @@ -147,15 +146,17 @@ def test(data,
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')

# W&B logging
if plots and len(wandb_images) < log_imgs:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
# W&B logging - Media Panel Plots
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
wandb_logger.log_training_progress(predn, path, names) # logs dsviz tables

# Append to pycocotools JSON dictionary
if save_json:
Expand Down Expand Up @@ -239,9 +240,11 @@ def test(data,
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
wandb.log({"Images": wandb_images, "Validation": val_batches}, commit=False)
if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
wandb_logger.log({"Validation": val_batches})
if wandb_images:
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

# Save JSON
if save_json and len(jdict):
Expand Down
116 changes: 61 additions & 55 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import argparse
import logging
import math
Expand Down Expand Up @@ -33,11 +34,12 @@
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file

logger = logging.getLogger(__name__)


def train(hyp, opt, device, tb_writer=None, wandb=None):
def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
Expand All @@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']
is_coco = opt.data.endswith('coco.yaml')

# Logging- Doing this before checking the dataset. Might update data_dict
if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
loggers = {'wandb': wandb_logger.wandb} # loggers dict
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
Expand All @@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(rank):
check_dataset(data_dict) # check
train_path = data_dict['train']
test_path = data_dict['val']

# Freeze
freeze = [] # parameter names to freeze (full or partial)
Expand Down Expand Up @@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)

# Logging
if rank in [-1, 0] and wandb and wandb.run is None:
opt.hyp = hyp # add hyperparameters
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem,
entity=opt.entity,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None

Expand Down Expand Up @@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# if tb_writer:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard
elif plots and ni == 10 and wandb:
wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
if x.exists()]}, commit=False)
elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]})

# end batch ------------------------------------------------------------------------------------------------
# end epoch ----------------------------------------------------------------------------------------------------
Expand All @@ -343,17 +346,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=batch_size * 2,
wandb_logger.current_epoch = epoch + 1
results, maps, times = test.test(data_dict,
batch_size=total_batch_size,
imgsz=imgsz_test,
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
log_imgs=opt.log_imgs if wandb else 0,
compute_loss=compute_loss)
wandb_logger=wandb_logger,
compute_loss=compute_loss,
is_coco=is_coco)

# Write
with open(results_file, 'a') as f:
Expand All @@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
if tb_writer:
tb_writer.add_scalar(tag, x, epoch) # tensorboard
if wandb:
wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
if wandb_logger.wandb:
wandb_logger.log({tag: x}) # W&B

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
Expand All @@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
torch.save(ckpt, best)
if wandb_logger.wandb:
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
wandb_logger.log_model(
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt

wandb_logger.end_epoch(best_result=best_fitness == fi)

# end epoch ----------------------------------------------------------------------------------------------------
# end training

if rank in [-1, 0]:
# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f)
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload

# Plots
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb:
if wandb_logger.wandb:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
if opt.log_artifacts:
wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)

wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
Expand All @@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False)
plots=False,
is_coco=is_coco)

# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
if wandb_logger.wandb: # Log the stripped model
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['last', 'best', 'stripped'])
else:
dist.destroy_process_group()

wandb.run.finish() if wandb and wandb.run else None
torch.cuda.empty_cache()
wandb_logger.finish_run()
return results


Expand Down Expand Up @@ -464,15 +473,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--quad', action='store_true', help='quad dataloader')
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
opt = parser.parse_args()

# Set DDP variables
Expand All @@ -484,7 +495,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
check_requirements()

# Resume
if opt.resume: # resume an interrupted run
wandb_run = resume_and_get_id(opt)
if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
apriori = opt.global_rank, opt.local_rank
Expand Down Expand Up @@ -517,18 +529,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Train
logger.info(opt)
try:
import wandb
except ImportError:
wandb = None
prefix = colorstr('wandb: ')
logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
if not opt.evolve:
tb_writer = None # init loggers
if opt.global_rank in [-1, 0]:
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer, wandb)
train(hyp, opt, device, tb_writer)

# Evolve hyperparameters (optional)
else:
Expand Down Expand Up @@ -602,7 +608,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
results = train(hyp.copy(), opt, device, wandb=wandb)
results = train(hyp.copy(), opt, device)

# Write mutation results
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
Expand Down
16 changes: 1 addition & 15 deletions utils/wandb_logging/log_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,14 @@
def create_dataset_artifact(opt):
with open(opt.data) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
logger = WandbLogger(opt, '', None, data, job_type='create_dataset')
nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset

# Update data.yaml with artifact links
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val')
path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path
data.pop('download', None) # download via artifact instead of predefined field 'download:'
with open(path, 'w') as f:
yaml.dump(data, f)
print("New Config file => ", path)
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml')
opt = parser.parse_args()

create_dataset_artifact(opt)
Loading