Skip to content

Commit

Permalink
[Add] sparsify.login CLI and function (#180)
Browse files Browse the repository at this point in the history
* Adding sparsify.login entrypoint and function

* Adding docstring to exception

* Adding pip install of sparsifyml

* Respond to review

* Adding help message at top

* Adding setup python to workflow

* Adding checked sparsifyml import

* Apply suggestions from code review

Co-authored-by: Danny Guinther <dannyguinther@gmail.com>

* check against major minor version only

* add client_id and other bug fixes

* Fix: `--index` --> `--index-url`

* Update install command missed during rebase

* * Clean up code
* Remove Global variables
* Update PyPi Server link
* Add Logging
* Move exceptions to their own file

* Style fixes

* Apply suggestions from code review

Add: suggestion from @KSGulin

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>

* Update src/sparsify/login.py

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>

* remove comment

---------

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: Danny Guinther <dannyguinther@gmail.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: rahul-tuli <rahul@neuralmagic.com>
Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
  • Loading branch information
6 people committed Apr 28, 2023
1 parent 36b64fd commit 47afad9
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 13 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _setup_entry_points() -> Dict:
return {
"console_scripts": [
"sparsify.run=sparsify.cli.run:main",
"sparsify.login=sparsify.login:main",
]
}

Expand Down
2 changes: 2 additions & 0 deletions src/sparsify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@

# flake8: noqa
# isort: skip_file

from .login import *
13 changes: 8 additions & 5 deletions src/sparsify/auto/utils/nm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@
Helper functions for communicating with the Neural Magic API
"""
import os
import warnings
from typing import Tuple

import requests

from sparsify.login import import_sparsifyml_authenticated
from sparsify.schemas import APIArgs, Metrics, SparsificationTrainingConfig
from sparsify.utils import get_base_url, strtobool


try:
from sparsifyml.auto import auto_training_config_initial, auto_training_config_tune
except (ImportError, ModuleNotFoundError):
warnings.warn("failed to import sparsifyml", ImportWarning)
sparsifyml = import_sparsifyml_authenticated()

from sparsifyml.auto import ( # noqa: E402
auto_training_config_initial,
auto_training_config_tune,
)


__all__ = ["api_request_config", "api_request_tune", "request_student_teacher_configs"]

Expand Down
160 changes: 160 additions & 0 deletions src/sparsify/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage: sparsify.login [OPTIONS]
sparsify.login utility to log into sparsify locally.
Options:
--api-key TEXT API key copied from your account. [required]
--version Show the version and exit. [default: False]
--help Show this message and exit. [default: False]
"""

import importlib
import json
import logging
import subprocess
import sys
from types import ModuleType
from typing import Optional

import click
from sparsezoo.analyze.cli import CONTEXT_SETTINGS
from sparsify.utils import (
credentials_exists,
get_access_token,
get_authenticated_pypi_url,
get_sparsify_credentials_path,
overwrite_credentials,
set_log_level,
)
from sparsify.utils.exceptions import SparsifyLoginRequired
from sparsify.version import version_major_minor


__all__ = [
"login",
"import_sparsifyml_authenticated",
"authenticate",
]


_LOGGER = logging.getLogger(__name__)


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
"--api-key", type=str, help="API key copied from your account.", required=True
)
@click.version_option(version=version_major_minor)
@click.option("--debug/--no-debug", default=False, hidden=True)
def main(api_key: str, debug: bool = False):
"""
sparsify.login utility to log into sparsify locally.
"""
set_log_level(logger=_LOGGER, level=logging.DEBUG if debug else logging.INFO)
_LOGGER.info("Logging into sparsify...")

login(api_key=api_key)

_LOGGER.debug(f"locals: {locals()}")
_LOGGER.info("Logged in successfully, sparsify setup is complete.")


def login(api_key: str) -> None:
"""
Logs into sparsify.
:param api_key: The API key copied from your account
:raises InvalidAPIKey: if the API key is invalid
"""
access_token = get_access_token(api_key)
overwrite_credentials(api_key=api_key)
install_sparsifyml(access_token)


def install_sparsifyml(access_token: str) -> None:
"""
Installs `sparsifyml` from the authenticated pypi server, if not already
installed or if the version is not the same as the current version.
:param access_token: The access token to use for authentication
"""
sparsifyml_spec = importlib.util.find_spec("sparsifyml")
sparsifyml = importlib.import_module("sparsifyml") if sparsifyml_spec else None

sparsifyml_installed = (
sparsifyml_spec is not None
and sparsifyml.version_major_minor == version_major_minor
)

if not sparsifyml_installed:
_LOGGER.info(
f"Installing sparsifyml version {version_major_minor} "
"from neuralmagic pypi server"
)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"--index-url",
get_authenticated_pypi_url(access_token=access_token),
f"sparsifyml_nightly>={version_major_minor}",
]
)
else:
_LOGGER.info(
f"sparsifyml version {version_major_minor} is already installed, "
"skipping installation from neuralmagic pypi server"
)


def import_sparsifyml_authenticated() -> Optional[ModuleType]:
"""
Authenticates and imports sparsifyml.
"""
authenticate()
import sparsifyml

return sparsifyml


def authenticate() -> None:
"""
Authenticates with sparsify server using the credentials stored on disk.
:raises SparsifyLoginRequired: if no valid credentials are found
"""
if not credentials_exists():
raise SparsifyLoginRequired(
"No valid sparsify credentials found. Please run `sparsify.login`"
)

with get_sparsify_credentials_path.open() as fp:
credentials = json.load(fp)

if "api_key" not in credentials:
raise SparsifyLoginRequired(
"No valid sparsify credentials found. Please run `sparsify.login`"
)

login(api_key=credentials["api_key"])


if __name__ == "__main__":
main()
11 changes: 3 additions & 8 deletions src/sparsify/one_shot/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,12 @@
import argparse
from pathlib import Path

from sparsify.login import import_sparsifyml_authenticated
from sparsify.utils import constants


try:
from sparsifyml import one_shot
except ImportError as e:

class SparsifyLoginRequired(Exception):
"""Exception when sparsifyml has not been installed by sparsify.login"""

raise SparsifyLoginRequired("Use `sparsify.login` to enable this command.") from e
sparsifyml = import_sparsifyml_authenticated()
from sparsifyml import one_shot # noqa: E402


__all__ = [
Expand Down
21 changes: 21 additions & 0 deletions src/sparsify/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class InvalidAPIKey(Exception):
"""The API key was invalid"""


class SparsifyLoginRequired(Exception):
"""Run `sparsify.login`"""
112 changes: 112 additions & 0 deletions src/sparsify/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import logging
from pathlib import Path

import requests

from sparsify.utils.exceptions import InvalidAPIKey


__all__ = [
"credentials_exists",
"get_access_token",
"get_authenticated_pypi_url",
"get_sparsify_credentials_path",
"get_token_url",
"overwrite_credentials",
"set_log_level",
"strtobool",
]
_MAP = {
Expand All @@ -30,6 +47,8 @@
"0": False,
}

_LOGGER = logging.getLogger(__name__)


def strtobool(value):
"""
Expand All @@ -43,3 +62,96 @@ def strtobool(value):
return _MAP[str(value).lower()]
except KeyError:
raise ValueError('"{}" is not a valid bool value'.format(value))


def get_token_url():
"""
:return: The url to use for getting an access token
"""
return "https://accounts.neuralmagic.com/v1/connect/token"


def get_sparsify_credentials_path() -> Path:
"""
:return: The path to the neuralmagic credentials file
"""
return Path.home().joinpath(".config", "neuralmagic", "credentials.json")


def credentials_exists() -> bool:
"""
:return: True if the credentials file exists, False otherwise
"""
return get_sparsify_credentials_path().exists()


def overwrite_credentials(api_key: str) -> None:
"""
Overwrite the credentials file with the given api key
or create a new file if it does not exist
:param api_key: The api key to write to the credentials file
"""
credentials_path = get_sparsify_credentials_path()
credentials_path.parent.mkdir(parents=True, exist_ok=True)
credentials = {"api_key": api_key}

with credentials_path.open("w") as fp:
json.dump(credentials, fp)


def get_access_token(api_key: str) -> str:
"""
Get the access token for the given api key
:param api_key: The api key to use for authentication
:return: The requested access token
"""
response = requests.post(
get_token_url(),
data={
"grant_type": "password",
"username": "api-key",
"client_id": "ee910196-cd8a-11ed-b74d-bb563cd16e9d",
"password": api_key,
"scope": "pypi:read",
},
)

try:
response.raise_for_status()
except requests.HTTPError as http_error:
error_message = (
"Sorry, we were unable to authenticate your Neural Magic Account API key. "
"If you believe this is a mistake, contact support@neuralmagic.com "
"to help remedy this issue."
)
raise InvalidAPIKey(error_message) from http_error

if response.status_code != 200:
raise ValueError(f"Unknown response code {response.status_code}")

_LOGGER.info("Successfully authenticated with Neural Magic Account API key")
return response.json()["access_token"]


def get_authenticated_pypi_url(access_token: str) -> str:
"""
Get the authenticated pypi url for the given access token
:return: The authenticated pypi url
"""
pypi_url_template = "https://nm:{}@pypi.neuralmagic.com"
return pypi_url_template.format(access_token)


def set_log_level(logger: logging.Logger, level: int) -> None:
"""
Set the log level for the given logger and all of its handlers
:param logger: The logger to set the level for
:param level: The level to set the logger to
"""
logging.basicConfig(level=level)
for handler in logger.handlers:
handler.setLevel(level=level)

0 comments on commit 47afad9

Please sign in to comment.