Skip to content

Commit

Permalink
sparsezoo download stub (#360)
Browse files Browse the repository at this point in the history
* sparsezoo download stub

* lint
  • Loading branch information
horheynm committed Aug 25, 2023
1 parent 24cfe83 commit 1f94aa9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
3 changes: 1 addition & 2 deletions src/sparsezoo/download_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def parse_args():
parser.add_argument(
"model_stub",
type=str,
help="Path to a SparseZoo model stub i.e. "
"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate",
help="Path to a SparseZoo model stub i.e. " "zoo:opt-1.3b-opt_pretrain-base",
)

parser.add_argument(
Expand Down
25 changes: 9 additions & 16 deletions src/sparsezoo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
import argparse
import logging

from sparsezoo import Model, model_args_to_stub, search_models
from sparsezoo import Model, search_models
from sparsezoo.analytics import sparsezoo_analytics


Expand Down Expand Up @@ -281,7 +281,12 @@ def parse_args():
SEARCH_COMMAND,
description="Search for objects from the repo.",
)
add_model_arguments(download_parser, download_required=True)

download_parser.add_argument(
"stub",
help="Model stub. Please visit sparsezoo.neuralmagic.com to obtain",
)

add_model_arguments(search_parser)

download_parser.add_argument(
Expand Down Expand Up @@ -373,20 +378,8 @@ def main():

if args.command == DOWNLOAD_COMMAND:
LOGGER.info("Downloading files from model...")
stub = model_args_to_stub(
domain=args.domain,
sub_domain=args.sub_domain,
architecture=args.architecture,
sub_architecture=args.sub_architecture,
framework=args.framework,
repo=args.repo,
dataset=args.dataset,
training_scheme=args.training_scheme,
sparse_name=args.sparse_name,
sparse_category=args.sparse_category,
sparse_target=args.sparse_target,
release_version=args.release_version,
)

stub = args.stub

if args.save_dir:
model = Model(stub, download_path=args.save_dir)
Expand Down

0 comments on commit 1f94aa9

Please sign in to comment.