Skip to content

Commit

Permalink
Multi-gpu patch
Browse files Browse the repository at this point in the history
Issue: low val mAP when training on multi-gpu setting
Fix:   Serialize --> Deserialize model for evaluation. Do not use the same model for training & eval.

test.test updated to work with the new requirements of train.py

Now, training gives the correct mAP even in a multi-GPU setting.
  • Loading branch information
akshaychawla committed Aug 17, 2020
1 parent 64ff05c commit 4e5eaab
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
15 changes: 11 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@ def test(cfg,
# Initialize/load model and set device
if model is None:
is_training = False
device = torch_utils.select_device(opt.device, batch_size=batch_size)
verbose = opt.task == 'test'
if 'opt' not in globals():
# test.test called from train.py but w/o passing model argument
# drawback of this patch: due to is_training=False, valid loss not computed
device = torch.device('cuda:0')
verbose= False
else:
device = torch_utils.select_device(opt.device, batch_size=batch_size)
verbose = opt.task == 'test'

# Remove previous
for f in glob.glob('test_batch*.jpg'):
Expand Down Expand Up @@ -63,7 +69,7 @@ def test(cfg,

# Dataloader
if dataloader is None:
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=opt.single_cls, pad=0.5)
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=single_cls, pad=0.5)
batch_size = min(batch_size, len(dataset))
dataloader = DataLoader(dataset,
batch_size=batch_size,
Expand Down Expand Up @@ -243,8 +249,9 @@ def test(cfg,
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--no-pycocotools', action='store_true', help='do not evaluate using pycocotools')
opt = parser.parse_args()
opt.save_json = opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])
opt.save_json = (not opt.no_pycocotools) and (opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]))
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
print(opt)
Expand Down
18 changes: 7 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from models import *
from utils.datasets import *
from utils.utils import *
import subprocess, tempfile

mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
Expand Down Expand Up @@ -199,7 +200,7 @@ def train(hyp):

# Dataloader
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
nw = opt.nw # min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
Expand Down Expand Up @@ -321,16 +322,10 @@ def train(hyp):
ema.update_attr(model)
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
results, maps = test.test(cfg,
data,
batch_size=batch_size,
imgsz=imgsz_test,
model=ema.ema,
save_json=final_epoch and is_coco,
single_cls=opt.single_cls,
dataloader=testloader,
multi_label=ni > n_burn)
with tempfile.TemporaryDirectory() as tmpdirname:
tmpckpt = {'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict()}
torch.save(tmpckpt, os.path.join(tmpdirname, 'ckpt.pt'))
results, maps = test.test(cfg=opt.cfg, weights=os.path.join(tmpdirname, 'ckpt.pt'), data=opt.data, imgsz=imgsz_test, save_json=False)

# Write
with open(results_file, 'a') as f:
Expand Down Expand Up @@ -407,6 +402,7 @@ def train(hyp):
parser.add_argument('--weights', type=str, default='weights/yolov3-spp-ultralytics.pt', help='initial weights path')
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--nw', type=int, default=8, help='number of dataloader workers')
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--freeze-layers', action='store_true', help='Freeze non-output layers')
Expand Down

0 comments on commit 4e5eaab

Please sign in to comment.