Skip to content

Commit

Permalink
Moving mutiplication by 100 to accuracy computation, training log fix
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavnmagic committed Jan 26, 2023
1 parent f7e325c commit a5a8ca7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ def train_one_epoch(
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)

acc1, _1, acc5, _5 = utils.accuracy(output, target, topk=(1, 5))
acc1, num_correct_1, acc5, num_correct_5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["acc1"].update(
acc1.item(), n=batch_size, total=num_correct_1)
metric_logger.meters["acc5"].update(
acc5.item(), n=batch_size, total=num_correct_5)
metric_logger.meters["imgs_per_sec"].update(
batch_size / (time.time() - start_time)
)
Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/pytorch/torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def avg(self):

@property
def global_avg(self):
return 100.0 * self.total / self.count
return self.total / self.count

@property
def max(self):
Expand Down Expand Up @@ -213,7 +213,7 @@ def accuracy(output, target, topk=(1,)):
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
res.append(correct_k)
res.append(correct_k * 100.0)
return res


Expand Down

0 comments on commit a5a8ca7

Please sign in to comment.