Skip to content

Commit

Permalink
fix compatibility for hyper config (#1146)
Browse files Browse the repository at this point in the history
* fix/hyper

* Hyp giou check to train.py

* restore general.py

* train.py overwrite fix

* restore general.py and pep8 update

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
Borda and glenn-jocher authored Oct 15, 2020
1 parent 4d3680c commit c67e722
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
11 changes: 8 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import time
from pathlib import Path
from warnings import warn

import math
import numpy as np
Expand Down Expand Up @@ -430,9 +431,8 @@ def train(hyp, opt, device, tb_writer=None):
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1

device = select_device(opt.device, batch_size=opt.batch_size)

# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
Expand All @@ -441,11 +441,16 @@ def train(hyp, opt, device, tb_writer=None):
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
opt.batch_size = opt.total_batch_size // opt.world_size

logger.info(opt)
# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
if 'box' not in hyp:
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
hyp['box'] = hyp.pop('giou')

# Train
logger.info(opt)
if not opt.evolve:
tb_writer = None
if opt.global_rank in [-1, 0]:
Expand Down
4 changes: 2 additions & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import glob
import logging
import math
import os
import platform
import random
import re
import shutil
import subprocess
import time
import re
from contextlib import contextmanager
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand Down

0 comments on commit c67e722

Please sign in to comment.