Skip to content

Commit

Permalink
Merge pull request #16 from neuralmagic/fix-sparsezoo-load-bug
Browse files Browse the repository at this point in the history
relax strict type check in sparsezoo load function
  • Loading branch information
markurtz committed Jul 29, 2021
2 parents 150b836 + d20dc38 commit 372ffa9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe=None, resume=None, rank=-1):
with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally
check_download_sparsezoo_weights(weights) # download from sparsezoo if zoo stub
weights = check_download_sparsezoo_weights(weights) # download from sparsezoo if zoo stub
ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
else weights, map_location=device) # load checkpoint
start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
Expand Down
6 changes: 3 additions & 3 deletions utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def check_download_sparsezoo_weights(path):

return path

if not isinstance(path, list):
raise ValueError(f"unknown type for path given: {path}")
if isinstance(path, list):
return [check_download_sparsezoo_weights(p) for p in path]

return [check_download_sparsezoo_weights(p) for p in path]
return path


class SparseMLWrapper(object):
Expand Down

0 comments on commit 372ffa9

Please sign in to comment.