Skip to content

Commit

Permalink
Model freeze capability (ultralytics#679)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Aug 11, 2020
1 parent 1ecd197 commit df68038
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def train(hyp, opt, device, tb_writer=None):
else:
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create

# Freeze
freeze = ['', ] # parameter names to freeze (full or partial)
if any(freeze):
for k, v in model.named_parameters():
if any(x in k for x in freeze):
print('freezing %s' % k)
v.requires_grad = False

# Optimizer
nbs = 64 # nominal batch size
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
Expand Down Expand Up @@ -125,7 +133,7 @@ def train(hyp, opt, device, tb_writer=None):
epochs += ckpt['epoch'] # finetune additional epochs

del ckpt, state_dict

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
Expand Down

0 comments on commit df68038

Please sign in to comment.