Merge pull request ultralytics#245 from yxNONG/patch-2
Unify the checkpoint of single and multi GPU
glenn-jocher committed Jul 2, 2020
2 parents 5e0bf24 + c3c6ba0 commit c1f4b79
Showing 2 changed files with 15 additions and 9 deletions.
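
For context on the change: nn.DataParallel and nn.DistributedDataParallel wrap the original network in a .module attribute and prefix its state_dict() keys with 'module.', so checkpoint and attribute handling can diverge between single- and multi-GPU runs unless the wrapper is unwrapped first. A minimal sketch of that unwrapping pattern, not part of this commit (the unwrap helper and the checkpoint file name are illustrative):

    import torch
    import torch.nn as nn

    def unwrap(model):
        # Return the underlying module of a DP/DDP-wrapped model, else the model itself
        parallel = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
        return model.module if parallel else model

    model = nn.Linear(10, 2)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # state_dict() keys now start with 'module.'

    # Saving the unwrapped state_dict keeps checkpoint keys identical on 1 or N GPUs
    torch.save(unwrap(model).state_dict(), 'checkpoint.pt')
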
4 changes: 2 additions & 2 deletions train.py
@@ -79,7 +79,6 @@ def train(hyp):
# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
@@ -178,6 +177,7 @@ def train(hyp):
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

# Class frequency
labels = np.concatenate(dataset.labels, 0)
@@ -294,7 +294,7 @@ def train(hyp):
batch_size=batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
model=ema.ema.module if hasattr(model, 'module') else ema.ema,
single_cls=opt.single_cls,
dataloader=testloader)

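The test() call above now evaluates the unwrapped EMA weights when training runs under DataParallel/DDP: the EMA holds a deep copy of the (possibly wrapped) training model (see update() below), so ema.ema.module is the bare network in the multi-GPU case and ema.ema in the single-GPU case. A hedged sketch of the same selection in isolation (the tiny Linear model is a placeholder):

    from copy import deepcopy
    import torch.nn as nn

    net = nn.Linear(4, 2)
    model = nn.DataParallel(net)      # training model, possibly wrapped
    ema_copy = deepcopy(model)        # the EMA copy mirrors the wrapping

    # Same expression as in the diff: unwrap only if the training model is wrapped
    eval_model = ema_copy.module if hasattr(model, 'module') else ema_copy
    print(type(eval_model).__name__)  # -> Linear
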
20 changes: 13 additions & 7 deletions utils/torch_utils.py
@@ -54,6 +54,11 @@ def time_synchronized():
return time.time()


def is_parallel(model):
# True if model is wrapped by DataParallel (DP) or DistributedDataParallel (DDP)
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def initialize_weights(model):
for m in model.modules():
t = type(m)
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):

try: # FLOPS
from thop import profile
macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
except:
fs = ''

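The FLOPS estimate above profiles a small 64x64 input and then scales by (640/64)^2 = 100 to report GFLOPS at the nominal 640x640 image size; the factor of 2 converts the multiply-accumulate counts (MACs) returned by thop into FLOPS. A rough sketch of the same arithmetic (assumes the thop package is installed; the torchvision model is a stand-in):

    from copy import deepcopy
    import torch
    import torchvision
    from thop import profile

    model = torchvision.models.resnet18()
    macs = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0]
    gflops_64 = macs / 1E9 * 2                 # GFLOPS at 64x64 (2 FLOPS per MAC)
    gflops_640 = gflops_64 * (640 / 64) ** 2   # x100 to estimate the 640x640 cost
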
@@ -185,7 +190,7 @@ def update(self, model):
self.updates += 1
d = self.decay(self.updates)
with torch.no_grad():
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
if is_parallel(model):
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
else:
msd, esd = model.state_dict(), self.ema.state_dict()
@@ -196,7 +201,8 @@
v += (1. - d) * msd[k].detach()

def update_attr(self, model):
# Assign attributes (which may change during training)
for k in model.__dict__.keys():
if not k.startswith('_'):
setattr(self.ema, k, getattr(model, k))
# Update class attributes
ema = self.ema.module if is_parallel(model) else self.ema
for k, v in model.__dict__.items():
if not k.startswith('_') and k != 'module':
setattr(ema, k, v)

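Taken together, is_parallel() gives the EMA class (ModelEMA in utils/torch_utils.py) a single code path for single-GPU and DP/DDP training: update() reads both state dicts through .module when the model is wrapped, and update_attr() copies training-time attributes onto the unwrapped EMA while skipping the 'module' reference itself. A hedged usage sketch, assuming it runs from the repository root (the toy model and the attached names attribute are placeholders):

    import torch.nn as nn
    from utils.torch_utils import ModelEMA, is_parallel

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    model = nn.DataParallel(model)    # the calls below also work without this line
    ema = ModelEMA(model)

    model.names = ['person', 'car']   # attribute attached during training, as in train.py above
    ema.update(model)                 # EMA of the weights, unwrapped via is_parallel()
    ema.update_attr(model)            # copies the attribute onto the unwrapped EMA module

    unwrapped = ema.ema.module if is_parallel(model) else ema.ema
    print(unwrapped.names)            # -> ['person', 'car']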