Skip to content

Commit

Permalink
Merge branch 'main' into torchvision-dp-support
Browse files Browse the repository at this point in the history
  • Loading branch information
corey-nm committed Jan 26, 2023
2 parents 86f9575 + 8353866 commit f3550bb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/sparseml/pytorch/image_classification/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,8 @@ def main(

if recipe is not None:
ScheduledModifierManager.from_yaml(recipe).apply_structure(model)

if checkpoint_path:
load_model(checkpoint_path, model, strict=True)
if checkpoint_path:
load_model(checkpoint_path, model, strict=True)

if one_shot is not None:
ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)
Expand Down
6 changes: 4 additions & 2 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def collate_fn(batch):

if args.distill_teacher not in ["self", "disable", None]:
_LOGGER.info("Instantiating teacher")
args.distill_teacher = _create_model(
distill_teacher = _create_model(
arch_key=args.teacher_arch_key,
local_rank=local_rank,
pretrained=True, # teacher is always pretrained
Expand All @@ -393,6 +393,8 @@ def collate_fn(batch):
device=device,
num_classes=num_classes,
)
else:
distill_teacher = args.distill_teacher

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down Expand Up @@ -576,7 +578,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
model,
epoch=args.start_epoch,
loggers=logger,
distillation_teacher=args.distill_teacher,
distillation_teacher=distill_teacher,
)
step_wrapper = manager.modify(
model,
Expand Down

0 comments on commit f3550bb

Please sign in to comment.