From df680380d4584dba692e7576d033077e330b4f36 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 10 Aug 2020 22:49:43 -0700 Subject: [PATCH] Model freeze capability (#679) --- train.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 86a9dbc291c7..e6a1d15cdd40 100644 --- a/train.py +++ b/train.py @@ -73,6 +73,14 @@ def train(hyp, opt, device, tb_writer=None): else: model = Model(opt.cfg, ch=3, nc=nc).to(device) # create + # Freeze + freeze = ['', ] # parameter names to freeze (full or partial) + if any(freeze): + for k, v in model.named_parameters(): + if any(x in k for x in freeze): + print('freezing %s' % k) + v.requires_grad = False + # Optimizer nbs = 64 # nominal batch size accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing @@ -125,7 +133,7 @@ def train(hyp, opt, device, tb_writer=None): epochs += ckpt['epoch'] # finetune additional epochs del ckpt, state_dict - + # Image sizes gs = int(max(model.stride)) # grid size (max stride) imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples