diff --git a/Dockerfile b/Dockerfile
index 8840596483fc..59b1a03d210c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
 # Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
-FROM nvcr.io/nvidia/pytorch:20.08-py3
+FROM nvcr.io/nvidia/pytorch:20.09-py3
 
 # Install dependencies
 RUN pip install --upgrade pip
diff --git a/data/hyp.finetune.yaml b/data/hyp.finetune.yaml
index fe9cd55019f7..1b84cff95c2c 100644
--- a/data/hyp.finetune.yaml
+++ b/data/hyp.finetune.yaml
@@ -15,7 +15,7 @@ weight_decay: 0.00036
 warmup_epochs: 2.0
 warmup_momentum: 0.5
 warmup_bias_lr: 0.05
-giou: 0.0296
+box: 0.0296
 cls: 0.243
 cls_pw: 0.631
 obj: 0.301
diff --git a/data/hyp.scratch.yaml b/data/hyp.scratch.yaml
index 9f53e86dd3ab..43354316c095 100644
--- a/data/hyp.scratch.yaml
+++ b/data/hyp.scratch.yaml
@@ -10,7 +10,7 @@ weight_decay: 0.0005  # optimizer weight decay 5e-4
 warmup_epochs: 3.0  # warmup epochs (fractions ok)
 warmup_momentum: 0.8  # warmup initial momentum
 warmup_bias_lr: 0.1  # warmup initial bias lr
-giou: 0.05  # box loss gain
+box: 0.05  # box loss gain
 cls: 0.5  # cls loss gain
 cls_pw: 1.0  # cls BCELoss positive_weight
 obj: 1.0  # obj loss gain (scale with pixels)
diff --git a/data/scripts/get_coco.sh b/data/scripts/get_coco.sh
index 7f86377070a5..157a0b04cf86 100755
--- a/data/scripts/get_coco.sh
+++ b/data/scripts/get_coco.sh
@@ -8,14 +8,17 @@
 # /yolov5
 
 # Download/unzip labels
-echo 'Downloading COCO 2017 labels ...'
 d='../' # unzip directory
-f='coco2017labels.zip' && curl -L https://github.com/ultralytics/yolov5/releases/download/v1.0/$f -o $f
-unzip -q $f -d $d && rm $f
+url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
+f='coco2017labels.zip' # 68 MB
+echo 'Downloading' $url$f ' ...' && curl -L $url$f -o $f && unzip -q $f -d $d && rm $f # download, unzip, remove
 
 # Download/unzip images
-echo 'Downloading COCO 2017 images ...'
 d='../coco/images' # unzip directory
-f='train2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 19G, 118k images
-f='val2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 1G, 5k images
-# f='test2017.zip' && curl http://images.cocodataset.org/zips/$f -o $f && unzip -q $f -d $d && rm $f # 7G, 41k images
+url=http://images.cocodataset.org/zips/
+f1='train2017.zip' # 19G, 118k images
+f2='val2017.zip' # 1G, 5k images
+f3='test2017.zip' # 7G, 41k images (optional)
+for f in $f1 $f2; do
+  echo 'Downloading' $url$f ' ...' && curl -L $url$f -o $f && unzip -q $f -d $d && rm $f # download, unzip, remove
+done
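The rewritten get_coco.sh collapses three near-identical curl invocations into a `url` variable plus a loop, with the optional test2017.zip declared but deliberately left out of the loop. For readers who prefer to drive the same download-unzip-remove pattern from Python, here is a minimal standard-library sketch; the shell script above is the actual implementation, and the paths and filenames below simply mirror it:

```python
import urllib.request
import zipfile
from pathlib import Path

url = 'http://images.cocodataset.org/zips/'
d = Path('../coco/images')  # unzip directory, as in the script
d.mkdir(parents=True, exist_ok=True)
for f in ('train2017.zip', 'val2017.zip'):  # add 'test2017.zip' if needed
    urllib.request.urlretrieve(url + f, f)  # download
    zipfile.ZipFile(f).extractall(path=d)   # unzip
    Path(f).unlink()                        # remove the archive
```

Note that the new loop also adds `curl -L`, so the image downloads now follow redirects the way the labels download already did.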
diff --git a/data/scripts/get_voc.sh b/data/scripts/get_voc.sh
index 5658864f2251..5e488c827c59 100644
--- a/data/scripts/get_voc.sh
+++ b/data/scripts/get_voc.sh
@@ -8,79 +8,23 @@
 # /yolov5
 
 start=$(date +%s)
-
-# handle optional download dir
-if [ -z "$1" ]; then
-  # navigate to ~/tmp
-  echo "navigating to ../tmp/ ..."
-  mkdir -p ../tmp
-  cd ../tmp/
-else
-  # check if is valid directory
-  if [ ! -d $1 ]; then
-    echo $1 "is not a valid directory"
-    exit 0
-  fi
-  echo "navigating to" $1 "..."
-  cd $1
-fi
-
-echo "Downloading VOC2007 trainval ..."
-# Download data
-curl -LO http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar
-echo "Downloading VOC2007 test data ..."
-curl -LO http://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar
-echo "Done downloading."
-
-# Extract data
-echo "Extracting trainval ..."
-tar -xf VOCtrainval_06-Nov-2007.tar
-echo "Extracting test ..."
-tar -xf VOCtest_06-Nov-2007.tar
-echo "removing tars ..."
-rm VOCtrainval_06-Nov-2007.tar
-rm VOCtest_06-Nov-2007.tar
-
-end=$(date +%s)
-runtime=$((end - start))
-
-echo "Completed in" $runtime "seconds"
-
-start=$(date +%s)
-
-# handle optional download dir
-if [ -z "$1" ]; then
-  # navigate to ~/tmp
-  echo "navigating to ../tmp/ ..."
-  mkdir -p ../tmp
-  cd ../tmp/
-else
-  # check if is valid directory
-  if [ ! -d $1 ]; then
-    echo $1 "is not a valid directory"
-    exit 0
-  fi
-  echo "navigating to" $1 "..."
-  cd $1
-fi
-
-echo "Downloading VOC2012 trainval ..."
-# Download data
-curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
-echo "Done downloading."
-
-# Extract data
-echo "Extracting trainval ..."
-tar -xf VOCtrainval_11-May-2012.tar
-echo "removing tar ..."
-rm VOCtrainval_11-May-2012.tar
+mkdir -p ../tmp
+cd ../tmp/
+
+# Download/unzip images and labels
+d='.' # unzip directory
+url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
+f1=VOCtrainval_06-Nov-2007.zip # 446MB, 5012 images
+f2=VOCtest_06-Nov-2007.zip # 438MB, 4953 images
+f3=VOCtrainval_11-May-2012.zip # 1.95GB, 17126 images
+for f in $f1 $f2 $f3; do
+  echo 'Downloading' $url$f ' ...' && curl -L $url$f -o $f && unzip -q $f -d $d && rm $f # download, unzip, remove
+done
 
 end=$(date +%s)
 runtime=$((end - start))
-
 echo "Completed in" $runtime "seconds"
 
-cd ../tmp
 echo "Spliting dataset..."
 python3 - "$@" <=4.41.0
 # export --------------------------------------
 # packaging  # for coremltools
-# coremltools==4.0b4
+# coremltools==4.0
 # onnx>=1.7.0
 # scikit-learn==0.19.2  # for coreml quantization
diff --git a/sotabench.py b/sotabench.py
index daef5168b213..9507d0754e95 100644
--- a/sotabench.py
+++ b/sotabench.py
@@ -1,6 +1,5 @@
 import argparse
 import glob
-import json
 import os
 import shutil
 from pathlib import Path
@@ -8,19 +7,17 @@
 import numpy as np
 import torch
 import yaml
+from sotabencheval.object_detection import COCOEvaluator
+from sotabencheval.utils import is_server
 from tqdm import tqdm
 
 from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import (
     coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, non_max_suppression, scale_coords,
-    xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class, set_logging)
+    xyxy2xywh, clip_coords, set_logging)
 from utils.torch_utils import select_device, time_synchronized
-
-from sotabencheval.object_detection import COCOEvaluator
-from sotabencheval.utils import is_server
-
 
 DATA_ROOT = './.data/vision/coco' if is_server() else '../coco'  # sotabench data dir
@@ -113,7 +110,7 @@ def test(data,
 
         # Compute loss
         if training:  # if model has loss hyperparameters
-            loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # GIoU, obj, cls
+            loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls
 
         # Run NMS
         t = time_synchronized()
diff --git a/test.py b/test.py
index e0bb7726f7d1..9e79a769f884 100644
--- a/test.py
+++ b/test.py
@@ -106,7 +106,7 @@ def test(data,
 
         # Compute loss
         if training:  # if model has loss hyperparameters
-            loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # GIoU, obj, cls
+            loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls
 
         # Run NMS
         t = time_synchronized()
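The comment edits in sotabench.py and test.py track the rename only; the slicing itself is unchanged. `compute_loss` returns the total loss plus a detached tensor of per-component losses, so `[1][:3]` picks out the three values the comments now call box, obj, cls. A stub illustrating that return contract; the stub is a sketch of the shape, not the repo's implementation:

```python
import torch

def compute_loss_stub(lbox, lobj, lcls):
    # Mirrors the return shape of utils.general.compute_loss:
    # (total loss, detached per-component tensor)
    loss = lbox + lobj + lcls
    return loss, torch.cat((lbox, lobj, lcls, loss)).detach()

_, items = compute_loss_stub(torch.tensor([0.05]), torch.tensor([0.30]), torch.tensor([0.24]))
box, obj, cls = items[:3]  # the components the comments now call box, obj, cls
```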
diff --git a/train.py b/train.py
index 4060a5701a8b..42774b880fc6 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,13 @@
 import argparse
 import logging
-import math
 import os
 import random
 import shutil
 import time
 from pathlib import Path
+from warnings import warn
 
+import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -195,7 +196,7 @@ def train(hyp, opt, device, tb_writer=None):
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
     model.nc = nc  # attach number of classes to model
     model.hyp = hyp  # attach hyperparameters to model
-    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
+    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
     model.names = names
@@ -204,10 +205,11 @@ def train(hyp, opt, device, tb_writer=None):
     nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     maps = np.zeros(nc)  # mAP per class
-    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
+    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = amp.GradScaler(enabled=cuda)
-    logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n'
+    logger.info('Image sizes %g train, %g test\n'
+                'Using %g dataloader workers\nLogging results to %s\n'
                 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         model.train()
@@ -234,7 +236,7 @@ def train(hyp, opt, device, tb_writer=None):
         if rank != -1:
             dataloader.sampler.set_epoch(epoch)
         pbar = enumerate(dataloader)
-        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
+        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
         if rank in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
@@ -245,7 +247,7 @@ def train(hyp, opt, device, tb_writer=None):
             # Warmup
             if ni <= nw:
                 xi = [0, nw]  # x interp
-                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
+                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
@@ -319,21 +321,21 @@ def train(hyp, opt, device, tb_writer=None):
 
             # Write
             with open(results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
+                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
             if len(opt.name) and opt.bucket:
                 os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
 
             # Tensorboard
             if tb_writer:
-                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                         'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                        'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                         'x/lr0', 'x/lr1', 'x/lr2']  # params
                 for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                     tb_writer.add_scalar(tag, x, epoch)
 
             # Update best mAP
-            fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
+            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                 best_fitness = fi
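The corrected comment matters: `results` holds P, R, mAP@.5, mAP@.5-.95 and the three validation losses, and `fitness` weights only the first four metrics. For reference, a sketch of that weighting; the weights shown match `utils.general.fitness` around this revision as best I can tell, so verify them against your checkout before relying on them:

```python
import numpy as np

def fitness(x):
    # x rows: [P, R, mAP@0.5, mAP@0.5:0.95, ...]
    w = [0.0, 0.0, 0.1, 0.9]  # mAP@0.5:0.95 dominates model selection
    return (x[:, :4] * w).sum(1)

results = (0.70, 0.60, 0.55, 0.35, 0.04, 0.70, 0.01)  # P, R, mAP@.5, mAP@.5-.95, val losses
fi = fitness(np.array(results).reshape(1, -1))  # 0.55*0.1 + 0.35*0.9 = 0.37
```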
@@ -393,7 +395,7 @@ def train(hyp, opt, device, tb_writer=None):
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
     parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
-    parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
+    parser.add_argument('--name', default='', help='renames experiment folder exp{N} to exp{N}_{name} if supplied')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
@@ -429,9 +431,8 @@ def train(hyp, opt, device, tb_writer=None):
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name)  # runs/exp1
 
-    device = select_device(opt.device, batch_size=opt.batch_size)
-
     # DDP mode
+    device = select_device(opt.device, batch_size=opt.batch_size)
     if opt.local_rank != -1:
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
@@ -440,15 +441,20 @@ def train(hyp, opt, device, tb_writer=None):
         assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
         opt.batch_size = opt.total_batch_size // opt.world_size
 
-    logger.info(opt)
+    # Hyperparameters
     with open(opt.hyp) as f:
         hyp = yaml.load(f, Loader=yaml.FullLoader)  # load hyps
+        if 'box' not in hyp:
+            warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
+                 (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
+            hyp['box'] = hyp.pop('giou')
 
     # Train
+    logger.info(opt)
     if not opt.evolve:
         tb_writer = None
         if opt.global_rank in [-1, 0]:
-            logger.info('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
+            logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.logdir}", view at http://localhost:6006/')
             tb_writer = SummaryWriter(log_dir=log_dir)  # runs/exp0
 
         train(hyp, opt, device, tb_writer)
@@ -463,7 +469,7 @@ def train(hyp, opt, device, tb_writer=None):
                 'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                 'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                 'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
-                'giou': (1, 0.02, 0.2),  # GIoU loss gain
+                'box': (1, 0.02, 0.2),  # box loss gain
                 'cls': (1, 0.2, 4.0),  # cls loss gain
                 'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                 'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
@@ -488,7 +494,7 @@ def train(hyp, opt, device, tb_writer=None):
         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
-        yaml_file = Path('runs/evolve/hyp_evolved.yaml')  # save best result here
+        yaml_file = Path(opt.logdir) / 'evolve' / 'hyp_evolved.yaml'  # save best result here
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists
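The new Hyperparameters block above is the user-facing half of the rename: a hyp file that still says `giou` loads with a warning and is migrated in memory rather than rejected. A round-trip illustration of the same shim; the file name here is hypothetical:

```python
import yaml
from warnings import warn

with open('hyp.legacy.yaml') as f:  # hypothetical file still containing "giou: 0.05"
    hyp = yaml.load(f, Loader=yaml.FullLoader)
if 'box' not in hyp:
    warn('Compatibility: "giou" was renamed to "box" (ultralytics/yolov5 PR #1120)')
    hyp['box'] = hyp.pop('giou')  # migrate old key to new key
assert 'box' in hyp and 'giou' not in hyp
```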
@@ -532,5 +538,5 @@ def train(hyp, opt, device, tb_writer=None):
 
         # Plot results
         plot_evolution(yaml_file)
-        print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
-              'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))
+        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
+              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
diff --git a/utils/datasets.py b/utils/datasets.py
index 29ee4b051e85..9192dec4b7d9 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1,5 +1,4 @@
 import glob
-import math
 import os
 import random
 import shutil
@@ -8,6 +7,7 @@
 from threading import Thread
 
 import cv2
+import math
 import numpy as np
 import torch
 from PIL import Image, ExifTags
diff --git a/utils/general.py b/utils/general.py
index da016631e589..f8415feef999 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -1,9 +1,9 @@
 import glob
 import logging
-import math
 import os
 import platform
 import random
+import re
 import shutil
 import subprocess
 import time
@@ -12,6 +12,7 @@
 from pathlib import Path
 
 import cv2
+import math
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
@@ -142,7 +143,7 @@ def check_dataset(dict):
     if val and len(val):
         val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])]  # val path
         if not all(os.path.exists(x) for x in val):
-            print('\nWARNING: Dataset not found, nonexistant paths: %s' % [*val])
+            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [*val])
             if s and len(s):  # download script
                 print('Downloading %s ...' % s)
                 if s.startswith('http') and s.endswith('.zip'):  # URL
@@ -157,7 +158,7 @@
 
 def make_divisible(x, divisor):
-    # Returns x evenly divisble by divisor
+    # Returns x evenly divisible by divisor
     return math.ceil(x / divisor) * divisor
@@ -168,9 +169,9 @@ def labels_to_class_weights(labels, nc=80):
 
     labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
     classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
-    weights = np.bincount(classes, minlength=nc)  # occurences per class
+    weights = np.bincount(classes, minlength=nc)  # occurrences per class
 
-    # Prepend gridpoint count (for uCE trianing)
+    # Prepend gridpoint count (for uCE training)
     # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
     # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start
@@ -509,11 +510,11 @@ def compute_loss(p, targets, model):  # predictions, targets, model
                 pxy = ps[:, :2].sigmoid() * 2. - 0.5
                 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                 pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
-                giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # giou(prediction, target)
-                lbox += (1.0 - giou).mean()  # giou loss
+                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
+                lbox += (1.0 - iou).mean()  # iou loss
 
                 # Objectness
-                tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio
+                tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio
 
                 # Classification
                 if model.nc > 1:  # cls loss (only if multiple classes)
@@ -528,7 +529,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
         lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss
 
     s = 3 / np  # output count scaling
-    lbox *= h['giou'] * s
+    lbox *= h['box'] * s
     lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
     lcls *= h['cls'] * s
     bs = tobj.shape[0]  # batch size
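Inside `compute_loss` the rename is deliberately cosmetic: the value is still the CIoU score returned by `bbox_iou`, and it still serves double duty as the box-loss term and the objectness target. With `model.gr = 1.0`, the target reduces to the clamped IoU itself. A small numeric sketch of that line:

```python
import torch

gr = 1.0  # model.gr set in train.py: obj target is (1 - gr) + gr * iou
iou = torch.tensor([0.83, 0.41, -0.10])  # CIoU can be negative, hence clamp(0)
tobj = (1.0 - gr) + gr * iou.detach().clamp(0)
print(tobj)  # tensor([0.8300, 0.4100, 0.0000])
```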
@@ -819,7 +820,7 @@ def print_results(k):
     k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
     k *= s
     wh = torch.tensor(wh, dtype=torch.float32)  # filtered
-    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unflitered
+    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
     k = print_results(k)
 
     # Plot
@@ -952,9 +953,12 @@ def increment_dir(dir, comment=''):
     # Increments a directory runs/exp1 --> runs/exp2_comment
     n = 0  # number
     dir = str(Path(dir))  # os-agnostic
-    d = sorted(glob.glob(dir + '*'))  # directories
-    if len(d):
-        n = max([int(x[len(dir):x.rfind('_') if '_' in Path(x).name else None]) for x in d]) + 1  # increment
+    dirs = sorted(glob.glob(dir + '*'))  # directories
+    if dirs:
+        matches = [re.search(r"exp(\d+)", d) for d in dirs]
+        idxs = [int(m.groups()[0]) for m in matches if m]
+        if idxs:
+            n = max(idxs) + 1  # increment
     return dir + str(n) + ('_' + comment if comment else '')
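The `increment_dir` rewrite is a behavioral fix rather than a rename: the old string slicing assumed every matching directory was exactly `exp{N}` or `exp{N}_{comment}`, while the regex version extracts the run number and simply skips anything that doesn't match. A quick check of the new logic; the directory names below are made up:

```python
import re

dirs = ['runs/exp0', 'runs/exp3_yolov5s', 'runs/exp11', 'runs/export']  # 'export' has no run number
matches = [re.search(r"exp(\d+)", d) for d in dirs]
idxs = [int(m.groups()[0]) for m in matches if m]  # [0, 3, 11]; 'export' is skipped
n = max(idxs) + 1  # next run directory -> runs/exp12
```

This also explains the updated `--name` help text in train.py: the suffix now lands on the experiment folder rather than on results.txt.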
@@ -1234,7 +1238,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.general im
 def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_results_overlay()
     # Plot training 'results*.txt', overlaying train and val losses
     s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
-    t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
+    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
     for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
         results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
         n = results.shape[1]  # number of rows
@@ -1254,13 +1258,13 @@ def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_
     fig.savefig(f.replace('.txt', '.png'), dpi=200)
 
 
-def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
-                 save_dir=''):  # from utils.general import *; plot_results()
+def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
+    # from utils.general import *; plot_results()
     # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
     fig, ax = plt.subplots(2, 5, figsize=(12, 6))
     ax = ax.ravel()
-    s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
-         'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
+    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
+         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
     if bucket:
         # os.system('rm -rf storage.googleapis.com')
         # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
@@ -1277,7 +1281,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
         for i in range(10):
             y = results[i, x]
             if i in [0, 1, 2, 5, 6, 7]:
-                y[y == 0] = np.nan  # dont show zero loss values
+                y[y == 0] = np.nan  # don't show zero loss values
                 # y /= y[0]  # normalize
             label = labels[fi] if len(labels) else Path(f).stem
             ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index c587617b821c..f6818238452f 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,9 +1,9 @@
 import logging
-import math
 import os
 import time
 from copy import deepcopy
 
+import math
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn