Skip to content

Commit

Permalink
requires grad after reset params
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Aug 18, 2022
1 parent e08d568 commit 5c854fa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions classify/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ def train(opt, device):
LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model
reshape_classifier_output(model, nc) # update class count
for p in model.parameters():
p.requires_grad = True # for training
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
m.p = opt.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names
Expand Down

0 comments on commit 5c854fa

Please sign in to comment.