diff --git a/CHANGELOG.D/2405.feature b/CHANGELOG.D/2405.feature new file mode 100644 index 000000000..4beb7e9b8 --- /dev/null +++ b/CHANGELOG.D/2405.feature @@ -0,0 +1 @@ +Configure version checker settings by plugins. diff --git a/neuro-cli/setup.cfg b/neuro-cli/setup.cfg index 66e3e5b8f..b4960bb16 100644 --- a/neuro-cli/setup.cfg +++ b/neuro-cli/setup.cfg @@ -57,11 +57,14 @@ install_requires = rich>=10.0.1 packaging>=20.0 jedi>=0.16 + importlib-metadata>=3.6; python_version<"3.10" [options.entry_points] console_scripts = neuro = neuro_cli.main:main docker-credential-neuro = neuro_cli.docker_credential_helper:main +neuro_api = + neuro-cli=neuro_cli.plugin:setup [options.packages.find] where=src diff --git a/neuro-cli/src/neuro_cli/plugin.py b/neuro-cli/src/neuro_cli/plugin.py new file mode 100644 index 000000000..f84243983 --- /dev/null +++ b/neuro-cli/src/neuro_cli/plugin.py @@ -0,0 +1,38 @@ +from neuro_sdk import ConfigScope, PluginManager + +NEURO_CLI_UPGRADE = """\ +You are using Neuro Platform Client {old_ver}, however {new_ver} is available. +You should consider upgrading via the following command: + python -m pip install --upgrade neuro-cli +""" + + +def get_neuro_cli_txt(old: str, new: str) -> str: + return NEURO_CLI_UPGRADE.format(old_ver=old, new_ver=new) + + +CERTIFI_UPGRADE = """\ +Your root certificates are out of date. +You are using certifi {old_ver}, however {new_ver} is available. +Please consider upgrading certifi package, e.g.: + python -m pip install --upgrade certifi +or + conda update certifi +""" + + +def get_certifi_txt(old: str, new: str) -> str: + return CERTIFI_UPGRADE.format(old_ver=old, new_ver=new) + + +def setup(manager: PluginManager) -> None: + # Setup config options + manager.config.define_str("job", "ps-format") + manager.config.define_str("job", "top-format") + manager.config.define_str("job", "life-span") + manager.config.define_str("job", "cluster-name", scope=ConfigScope.LOCAL) + manager.config.define_str_list("storage", "cp-exclude") + manager.config.define_str_list("storage", "cp-exclude-from-files") + + manager.version_checker.register("neuro-cli", get_neuro_cli_txt) + manager.version_checker.register("certifi", get_certifi_txt, delay=14 * 3600 * 24) diff --git a/neuro-cli/src/neuro_cli/root.py b/neuro-cli/src/neuro_cli/root.py index 70bd0fbde..4e581c6f3 100644 --- a/neuro-cli/src/neuro_cli/root.py +++ b/neuro-cli/src/neuro_cli/root.py @@ -27,7 +27,7 @@ from rich.text import Text as RichText from neuro_sdk import Client, ConfigError, Factory, gen_trace_id -from neuro_sdk.config import _ConfigData, load_user_config +from neuro_sdk.config import _ConfigData from .asyncio_utils import Runner @@ -190,7 +190,7 @@ async def get_user_config(self) -> Mapping[str, Any]: try: client = await self.init_client() except ConfigError: - return load_user_config(self.config_path.expanduser()) + return await self.factory.load_user_config() else: return await client.config.get_user_config() diff --git a/neuro-cli/src/neuro_cli/utils.py b/neuro-cli/src/neuro_cli/utils.py index fdd23818e..b9a46017c 100644 --- a/neuro-cli/src/neuro_cli/utils.py +++ b/neuro-cli/src/neuro_cli/utils.py @@ -36,7 +36,6 @@ from .parse_utils import parse_timedelta from .root import Root from .stats import upload_gmp_stats -from .version_utils import run_version_checker log = logging.getLogger(__name__) @@ -58,9 +57,17 @@ async def _run_async_function( if init_client: await root.init_client() - pypi_task: "asyncio.Task[None]" = loop.create_task( - run_version_checker(root.client, root.disable_pypi_version_check) - ) + if not root.disable_pypi_version_check: + msgs = await root.client.version_checker.get_outdated() + for msg in msgs.values(): + root.err_console.print(msg, style="yellow") + + pypi_task: "asyncio.Task[None]" = loop.create_task( + root.client.version_checker.update() + ) + else: + pypi_task = loop.create_task(asyncio.sleep(0)) # do nothing + stats_task: "asyncio.Task[None]" = loop.create_task( upload_gmp_stats( root.client, root.command_path, root.command_params, root.skip_gmp_stats diff --git a/neuro-cli/src/neuro_cli/version_utils.py b/neuro-cli/src/neuro_cli/version_utils.py deleted file mode 100644 index d5bcdfec6..000000000 --- a/neuro-cli/src/neuro_cli/version_utils.py +++ /dev/null @@ -1,207 +0,0 @@ -import contextlib -import logging -import sqlite3 -import time -from typing import Any, Dict, List, Optional, Tuple - -import aiohttp -import certifi -import click -import dateutil.parser -from packaging.version import parse as parse_version -from typing_extensions import TypedDict -from yarl import URL - -from neuro_sdk import Client - -import neuro_cli - - -class Record(TypedDict): - package: str - version: str - uploaded: float - checked: float - - -log = logging.getLogger(__name__) - - -SCHEMA = { - "pypi": "CREATE TABLE pypi " - "(package TEXT, version TEXT, uploaded REAL, checked REAL)", -} -DROP = {"pypi": "DROP TABLE IF EXISTS pypi"} - - -async def run_version_checker(client: Client, disable_check: bool) -> None: - if disable_check: - return - with client.config._open_db() as db: - _ensure_schema(db) - neurocli_db = _read_package(db, "neuro-cli") - certifi_db = _read_package(db, "certifi") - - _warn_maybe(neurocli_db, certifi_db) - inserts: List[Tuple[str, str, float, float]] = [] - await _add_record(client, "neuro-cli", neurocli_db, inserts) - await _add_record(client, "certifi", certifi_db, inserts) - with client.config._open_db() as db: - db.executemany( - """ - INSERT INTO pypi (package, version, uploaded, checked) - VALUES (?, ?, ?, ?) - """, - inserts, - ) - db.execute("DELETE FROM pypi WHERE checked < ?", (time.time() - 7 * 24 * 3600,)) - with contextlib.suppress(sqlite3.OperationalError): - db.commit() - - -async def _add_record( - client: Client, - package: str, - record: Optional[Record], - inserts: List[Tuple[str, str, float, float]], -) -> None: - if record is None or time.time() - record["checked"] > 10 * 60: - pypi = await _fetch_package(client._session, package) - if pypi is None: - return - inserts.append( - ( - pypi["package"], - pypi["version"], - pypi["uploaded"], - pypi["checked"], - ) - ) - - -def _ensure_schema(db: sqlite3.Connection) -> None: - cur = db.cursor() - ok = True - found = set() - cur.execute("SELECT type, name, sql from sqlite_master") - for type, name, sql in cur: - if type not in ("table", "index"): - continue - if name in SCHEMA: - if SCHEMA[name] != sql: - ok = False - break - else: - found.add(name) - - if not ok or found < SCHEMA.keys(): - for sql in reversed(list(DROP.values())): - cur.execute(sql) - for sql in SCHEMA.values(): - cur.execute(sql) - - -READ_PACKAGE = """ - SELECT package, version, uploaded, checked - FROM pypi - WHERE package = ? - ORDER BY checked - LIMIT 1 -""" - - -def _read_package(db: sqlite3.Connection, package: str) -> Optional[Record]: - cur = db.execute(READ_PACKAGE, (package,)) - return cur.fetchone() - - -async def _fetch_package( - session: aiohttp.ClientSession, package: str -) -> Optional[Record]: - url = URL(f"https://pypi.org/pypi/{package}/json") - async with session.get(url) as resp: - if resp.status != 200: - log.debug("%s status on fetching %s", resp.status, url) - return None - pypi_response = await resp.json() - version = _parse_max_version(pypi_response) - if version is None: - return None - uploaded = _parse_version_upload_time(pypi_response, version) - return { - "package": package, - "version": version, - "uploaded": uploaded, - "checked": time.time(), - } - - -def _parse_date(value: str) -> float: - # from format: "2019-08-19" - return dateutil.parser.parse(value).timestamp() - - -def _parse_max_version(pypi_response: Dict[str, Any]) -> Optional[str]: - try: - ret = [version for version in pypi_response["releases"].keys()] - return max(ver for ver in ret if not parse_version(ver).is_prerelease) - except (KeyError, ValueError): - return None - - -def _parse_version_upload_time( - pypi_response: Dict[str, Any], target_version: str -) -> float: - try: - dates = [ - _parse_date(info["upload_time"]) - for version, info_list in pypi_response["releases"].items() - for info in info_list - if version == target_version - ] - return max(dates) - except (KeyError, ValueError): - return 0 - - -def _warn_maybe( - neurocli_db: Optional[Record], - certifi_db: Optional[Record], - *, - certifi_warning_delay: int = 14 * 3600 * 24, -) -> None: - - if neurocli_db is not None: - current = parse_version(neuro_cli.__version__) - pypi = parse_version(neurocli_db["version"]) - if current < pypi: - update_command = "pip install --upgrade neuro-cli" - click.secho( - f"You are using Neuro Platform Client {current}, " - f"however {pypi} is available.\n" - f"You should consider upgrading via " - f"the '{update_command}' command.", - err=True, - fg="yellow", - ) - - if certifi_db is not None: - current = parse_version(certifi.__version__) # type: ignore - pypi = parse_version(certifi_db["version"]) - if ( - current < pypi - and time.time() - certifi_db["uploaded"] > certifi_warning_delay - ): - pip_update_command = "pip install --upgrade certifi" - conda_update_command = "conda update certifi" - click.secho( - f"Your root certificates are out of date.\n" - f"You are using certifi {current}, " - f"however {pypi} is available.\n" - f"Please consider upgrading certifi package, e.g.\n" - f" {pip_update_command}\n" - f"or\n" - f" {conda_update_command}", - err=True, - fg="red", - ) diff --git a/neuro-cli/tests/conftest.py b/neuro-cli/tests/conftest.py index cbd3804a5..30e9daf7b 100644 --- a/neuro-cli/tests/conftest.py +++ b/neuro-cli/tests/conftest.py @@ -10,7 +10,7 @@ from jose import jwt from yarl import URL -from neuro_sdk import Client, Cluster, Preset, __version__ +from neuro_sdk import Client, Cluster, PluginManager, Preset, __version__ from neuro_sdk.config import _AuthConfig, _AuthToken, _ConfigData, _save from neuro_sdk.tracing import _make_trace_config @@ -114,7 +114,8 @@ def go( registry_url: str = "https://registry-dev.neu.ro", trace_id: str = "bd7a977555f6b982", clusters: Optional[Dict[str, Cluster]] = None, - token_url: Optional[URL] = None + token_url: Optional[URL] = None, + plugin_manager: Optional[PluginManager] = None, ) -> Client: url = URL(url_str) if clusters is None: @@ -155,6 +156,8 @@ def go( real_auth_config = replace(auth_config, token_url=token_url) else: real_auth_config = auth_config + if plugin_manager is None: + plugin_manager = PluginManager() config = _ConfigData( auth_config=real_auth_config, auth_token=_AuthToken.create_non_expiring(token), @@ -167,6 +170,12 @@ def go( config_dir = tmp_path / ".neuro" _save(config, config_dir) session = aiohttp.ClientSession(trace_configs=[_make_trace_config()]) - return Client._create(session, config_dir, trace_id) + return Client._create( + session, + config_dir, + trace_id, + None, + plugin_manager=plugin_manager, + ) return go diff --git a/neuro-cli/tests/unit/test_job.py b/neuro-cli/tests/unit/test_job.py index 8665fb2ad..e964a2de7 100644 --- a/neuro-cli/tests/unit/test_job.py +++ b/neuro-cli/tests/unit/test_job.py @@ -17,6 +17,7 @@ JobRestartPolicy, JobStatus, JobStatusHistory, + PluginManager, RemoteImage, Resources, SecretFile, @@ -232,8 +233,13 @@ async def test_calc_top_columns_section_doesnt_exist( async def test_calc_ps_columns_user_spec( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "ps-format") - async with make_client("https://example.com") as client: + async with make_client( + "https://example.com", + plugin_manager=plugin_manager, + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" # empty config @@ -247,8 +253,12 @@ async def test_calc_ps_columns_user_spec( async def test_calc_top_columns_user_spec( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "top-format") - async with make_client("https://example.com") as client: + async with make_client( + "https://example.com", plugin_manager=plugin_manager + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" # empty config diff --git a/neuro-cli/tests/unit/test_storage.py b/neuro-cli/tests/unit/test_storage.py index 32ff16b84..793c4207d 100644 --- a/neuro-cli/tests/unit/test_storage.py +++ b/neuro-cli/tests/unit/test_storage.py @@ -3,7 +3,7 @@ import toml -from neuro_sdk import Client +from neuro_sdk import Client, PluginManager from neuro_cli.storage import calc_filters, calc_ignore_file_names @@ -25,8 +25,13 @@ async def test_calc_filters_section_doesnt_exist( async def test_calc_filters_user_spec( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str_list("storage", "cp-exclude") - async with make_client("https://example.com") as client: + async with make_client( + "https://example.com", + plugin_manager=plugin_manager, + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" local_conf.write_text( @@ -41,8 +46,13 @@ async def test_calc_filters_user_spec( async def test_calc_filters_user_spec_and_options( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str_list("storage", "cp-exclude") - async with make_client("https://example.com") as client: + async with make_client( + "https://example.com", + plugin_manager=plugin_manager, + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" local_conf.write_text( @@ -60,7 +70,13 @@ async def test_calc_filters_user_spec_and_options( async def test_calc_ignore_file_names_default( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: - async with make_client("https://example.com") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str_list("storage", "cp-exclude-from-files") + + async with make_client( + "https://example.com", + plugin_manager=plugin_manager, + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" # empty config @@ -73,7 +89,13 @@ async def test_calc_ignore_file_names_default( async def test_calc_ignore_file_names_user_spec( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: - async with make_client("https://example.com") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str_list("storage", "cp-exclude-from-files") + + async with make_client( + "https://example.com", + plugin_manager=plugin_manager, + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" local_conf.write_text( diff --git a/neuro-cli/tests/unit/test_utils.py b/neuro-cli/tests/unit/test_utils.py index c6c61ee4e..4c412b254 100644 --- a/neuro-cli/tests/unit/test_utils.py +++ b/neuro-cli/tests/unit/test_utils.py @@ -9,7 +9,7 @@ from aiohttp import web from yarl import URL -from neuro_sdk import Action, Client, JobStatus +from neuro_sdk import Action, Client, JobStatus, PluginManager from neuro_cli.parse_utils import parse_timedelta from neuro_cli.root import Root @@ -702,7 +702,12 @@ def test_pager_maybe_terminal_smaller() -> None: async def test_calc_life_span_none_default( monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: - async with make_client("https://example.com") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "life-span") + + async with make_client( + "https://example.com", plugin_manager=plugin_manager + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" local_conf.write_text(toml.dumps({"job": {"life-span": "1d2h3m4s"}})) @@ -715,7 +720,12 @@ async def test_calc_life_span_none_default( async def test_calc_life_span_default_life_span_all_keys( caplog: Any, monkeypatch: Any, tmp_path: Path, make_client: _MakeClient ) -> None: - async with make_client("https://example.com") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "life-span") + + async with make_client( + "https://example.com", plugin_manager=plugin_manager + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" # empty config @@ -733,7 +743,12 @@ async def test_calc_default_life_span_invalid( tmp_path: Path, make_client: _MakeClient, ) -> None: - async with make_client("https://example.com") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "life-span") + + async with make_client( + "https://example.com", plugin_manager=plugin_manager + ) as client: monkeypatch.chdir(tmp_path) local_conf = tmp_path / ".neuro.toml" # empty config diff --git a/neuro-sdk/src/neuro_sdk/__init__.py b/neuro-sdk/src/neuro_sdk/__init__.py index accbbecd1..45b6a1bc1 100644 --- a/neuro-sdk/src/neuro_sdk/__init__.py +++ b/neuro-sdk/src/neuro_sdk/__init__.py @@ -77,7 +77,7 @@ VolumeParseResult, ) from .parsing_utils import LocalImage, RemoteImage, Tag, TagOption -from .plugins import ConfigBuilder, PluginManager +from .plugins import ConfigBuilder, ConfigScope, PluginManager, VersionChecker from .secrets import Secret, Secrets from .server_cfg import Cluster from .service_accounts import ServiceAccount, ServiceAccounts @@ -112,6 +112,7 @@ "Config", "ConfigBuilder", "ConfigError", + "ConfigScope", "Container", "DEFAULT_API_URL", "DEFAULT_CONFIG_PATH", @@ -168,6 +169,7 @@ "Tag", "TagOption", "Users", + "VersionChecker", "Volume", "VolumeParseResult", "find_project_root", diff --git a/neuro-sdk/src/neuro_sdk/_version_utils.py b/neuro-sdk/src/neuro_sdk/_version_utils.py new file mode 100644 index 000000000..a0a87d8af --- /dev/null +++ b/neuro-sdk/src/neuro_sdk/_version_utils.py @@ -0,0 +1,205 @@ +import contextlib +import logging +import sqlite3 +import sys +import time +from typing import Any, Dict, List, Optional, Tuple + +import dateutil.parser +from packaging.version import parse as parse_version +from typing_extensions import TypedDict +from yarl import URL + +from .config import Config +from .core import _Core +from .plugins import PluginManager +from .utils import NoPublicConstructor + +if sys.version_info >= (3, 10): + from importlib.metadata import version +else: + from importlib_metadata import version + + +class _Record(TypedDict): + package: str + version: str + uploaded: float + checked: float + + +log = logging.getLogger(__package__) + + +class VersionChecker(metaclass=NoPublicConstructor): + _SCHEMA = { + "pypi": "CREATE TABLE pypi " + "(package TEXT, version TEXT, uploaded REAL, checked REAL)", + } + _DROP = {"pypi": "DROP TABLE IF EXISTS pypi"} + _READ_PACKAGE = """ + SELECT package, version, uploaded, checked + FROM pypi + WHERE package = ? + ORDER BY checked + LIMIT 1 + """ + + def __init__( + self, core: _Core, config: Config, plugin_manager: PluginManager + ) -> None: + self._core = core + self._config = config + self._plugin_manager = plugin_manager + self._records: Dict[str, _Record] = {} + self._loaded = False + + async def get_outdated(self) -> Dict[str, str]: + """Get packages that can be updated along with instructions for update. + + The information is collected from local database, updated by previous run. + """ + await self._read_db() + ret = {} + for package, record in self._records.items(): + assert package == record["package"] + spec = self._plugin_manager.version_checker._records.get(package) + if spec is None: + continue + current = parse_version(version(package)) + pypi = parse_version(record["version"]) + if current < pypi and time.time() - record["uploaded"] > spec.delay: + new_text = spec.update_text(str(current), str(pypi)) # type: ignore + if spec.exclusive: + return {package: new_text} + else: + ret[package] = new_text + return ret + + async def update(self) -> None: + """Update local database with packages information fetched from pypi""" + await self._read_db() + inserts: List[Tuple[str, str, float, float]] = [] + for package in self._plugin_manager.version_checker._records: + record = self._records.get(package) + await self._update_record(package, record, inserts) + + with self._config._open_db() as db: + db.executemany( + """ + INSERT INTO pypi (package, version, uploaded, checked) + VALUES (?, ?, ?, ?) + """, + inserts, + ) + db.execute( + "DELETE FROM pypi WHERE checked < ?", + (time.time() - 7 * 24 * 3600,), + ) + with contextlib.suppress(sqlite3.OperationalError): + db.commit() + + async def _read_db(self) -> None: + if self._loaded: + return + with self._config._open_db() as db: + self._ensure_schema(db) + for package in self._plugin_manager.version_checker._records: + record = self._read_package(db, package) + if record is not None: + self._records[package] = record + self._loaded = True + + async def _update_record( + self, + package: str, + record: Optional[_Record], + inserts: List[Tuple[str, str, float, float]], + ) -> None: + if record is None or time.time() - record["checked"] > 10 * 60: + pypi = await self._fetch_package(package) + if pypi is None: + return + inserts.append( + ( + pypi["package"], + pypi["version"], + pypi["uploaded"], + pypi["checked"], + ) + ) + self._records[pypi["package"]] = pypi + + def _ensure_schema(self, db: sqlite3.Connection) -> None: + cur = db.cursor() + ok = True + found = set() + cur.execute("SELECT type, name, sql from sqlite_master") + for type, name, sql in cur: + if type not in ("table", "index"): + continue + if name in self._SCHEMA: + if self._SCHEMA[name] != sql: + ok = False + break + else: + found.add(name) + + if not ok or found < self._SCHEMA.keys(): + for sql in reversed(list(self._DROP.values())): + cur.execute(sql) + for sql in self._SCHEMA.values(): + cur.execute(sql) + + def _read_package(self, db: sqlite3.Connection, package: str) -> Optional[_Record]: + cur = db.execute(self._READ_PACKAGE, (package,)) + return cur.fetchone() + + async def _fetch_package( + self, + package: str, + ) -> Optional[_Record]: + url = URL(f"https://pypi.org/pypi/{package}/json") + async with self._core._session.get(url) as resp: + if resp.status != 200: + log.debug("%s status on fetching %s", resp.status, url) + return None + pypi_response = await resp.json() + ver = _parse_max_version(pypi_response) + if ver is None: + return None + uploaded = _parse_version_upload_time(pypi_response, ver) + return { + "package": package, + "version": ver, + "uploaded": uploaded, + "checked": time.time(), + } + + +def _parse_date(value: str) -> float: + # from format: "2019-08-19" + return dateutil.parser.parse(value).timestamp() + + +def _parse_max_version(pypi_response: Dict[str, Any]) -> Optional[str]: + try: + ret = [ver1 for ver1 in pypi_response["releases"].keys()] + return max(ver2 for ver2 in ret if not parse_version(ver2).is_prerelease) + except (KeyError, ValueError): + return None + + +def _parse_version_upload_time( + pypi_response: Dict[str, Any], target_version: str +) -> float: + try: + dates = [ + _parse_date(info["upload_time"]) + for ver, info_list in pypi_response["releases"].items() + for info in info_list + if ver == target_version + ] + return max(dates) + except (KeyError, ValueError): + return 0 diff --git a/neuro-sdk/src/neuro_sdk/client.py b/neuro-sdk/src/neuro_sdk/client.py index aee0b47ac..92c7b35d0 100644 --- a/neuro-sdk/src/neuro_sdk/client.py +++ b/neuro-sdk/src/neuro_sdk/client.py @@ -6,6 +6,7 @@ from neuro_sdk.service_accounts import ServiceAccounts +from ._version_utils import VersionChecker from .admin import _Admin from .buckets import Buckets from .config import Config @@ -14,6 +15,7 @@ from .images import Images from .jobs import Jobs from .parser import Parser +from .plugins import PluginManager from .secrets import Secrets from .server_cfg import Preset from .storage import Storage @@ -27,12 +29,14 @@ def __init__( session: aiohttp.ClientSession, path: Path, trace_id: Optional[str], - trace_sampled: Optional[bool] = None, + trace_sampled: Optional[bool], + plugin_manager: PluginManager, ) -> None: self._closed = False self._session = session + self._plugin_manager = plugin_manager self._core = _Core(session, trace_id, trace_sampled) - self._config = Config._create(self._core, path) + self._config = Config._create(self._core, path, plugin_manager) # Order does matter, need to check the main config before loading # the storage cookie session @@ -51,6 +55,9 @@ def __init__( self._service_accounts = ServiceAccounts._create(self._core, self._config) self._buckets = Buckets._create(self._core, self._config, self._parser) self._images: Optional[Images] = None + self._version_checker: VersionChecker = VersionChecker._create( + self._core, self._config, plugin_manager + ) async def close(self) -> None: if self._closed: @@ -129,3 +136,7 @@ def buckets(self) -> Buckets: @property def parse(self) -> Parser: return self._parser + + @property + def version_checker(self) -> VersionChecker: + return self._version_checker diff --git a/neuro-sdk/src/neuro_sdk/config.py b/neuro-sdk/src/neuro_sdk/config.py index 6d4c0f4d3..1d1ffa69f 100644 --- a/neuro-sdk/src/neuro_sdk/config.py +++ b/neuro-sdk/src/neuro_sdk/config.py @@ -20,16 +20,10 @@ from .core import _Core from .errors import ConfigError from .login import AuthTokenClient, _AuthConfig, _AuthToken -from .plugins import PluginManager +from .plugins import ConfigScope, PluginManager, _ParamType from .server_cfg import Cluster, Preset, _ServerConfig, get_server_config from .utils import NoPublicConstructor, find_project_root, flat -if sys.version_info >= (3, 10): - from importlib.metadata import entry_points -else: - from importlib_metadata import entry_points - - WIN32 = sys.platform == "win32" CMD_RE = re.compile("[A-Za-z][A-Za-z0-9-]*") @@ -75,9 +69,10 @@ class _ConfigRecoveryData: class Config(metaclass=NoPublicConstructor): - def __init__(self, core: _Core, path: Path) -> None: + def __init__(self, core: _Core, path: Path, plugin_manager: PluginManager) -> None: self._core = core self._path = path + self._plugin_manager = plugin_manager self.__config_data: Optional[_ConfigData] = None def _load(self) -> _ConfigData: @@ -251,10 +246,10 @@ async def _registry_auth(self) -> str: ).decode("ascii") async def get_user_config(self) -> Mapping[str, Any]: - return load_user_config(self._path) + return _load_user_config(self._plugin_manager, self._path) def _get_user_config(self) -> Mapping[str, Any]: - return load_user_config(self._path) + return _load_user_config(self._plugin_manager, self._path) @contextlib.contextmanager def _open_db(self, suppress_errors: bool = True) -> Iterator[sqlite3.Connection]: @@ -262,7 +257,7 @@ def _open_db(self, suppress_errors: bool = True) -> Iterator[sqlite3.Connection] yield db -def load_user_config(path: Path) -> Mapping[str, Any]: +def _load_user_config(plugin_manager: PluginManager, path: Path) -> Mapping[str, Any]: # TODO: search in several locations (HOME+curdir), # merge found configs filename = path / "user.toml" @@ -272,14 +267,14 @@ def load_user_config(path: Path) -> Mapping[str, Any]: elif not filename.is_file(): raise ConfigError(f"User config {filename} should be a regular file") else: - config = _load_file(filename, allow_cluster_name=False) + config = _load_file(plugin_manager, filename, allow_cluster_name=False) try: project_root = find_project_root() except ConfigError: return config else: filename = project_root / ".neuro.toml" - local_config = _load_file(filename, allow_cluster_name=True) + local_config = _load_file(plugin_manager, filename, allow_cluster_name=True) return _merge_user_configs(config, local_config) @@ -646,7 +641,7 @@ def _check_item( def _check_section( config: Mapping[str, Any], section: str, - params: Dict[str, Any], + params: Mapping[str, _ParamType], filename: Union[str, "os.PathLike[str]"], ) -> None: sec = config.get(section) @@ -664,6 +659,7 @@ def _check_section( def _validate_user_config( + plugin_manager: PluginManager, config: Mapping[str, Any], filename: Union[str, "os.PathLike[str]"], allow_cluster_name: bool = False, @@ -676,25 +672,16 @@ def _validate_user_config( # # Since currently CLI is the only API client that reads user config data, API # validates it. - plugin_manager = PluginManager() - plugin_manager.config.define_str("job", "ps-format") - plugin_manager.config.define_str("job", "top-format") - plugin_manager.config.define_str("job", "life-span") - if allow_cluster_name: - plugin_manager.config.define_str("job", "cluster-name") - else: + if not allow_cluster_name: if "cluster-name" in config.get("job", {}): raise ConfigError( f"{filename}: cluster name is not allowed in global user " f"config file, use 'neuro config switch-cluster' for " f"changing the default cluster name" ) - - plugin_manager.config.define_str_list("storage", "cp-exclude") - plugin_manager.config.define_str_list("storage", "cp-exclude-from-files") - for entry_point in entry_points(group="neuro_api"): - entry_point.load()(plugin_manager) - config_spec = plugin_manager.config._get_spec() + config_spec = plugin_manager.config._get_spec( + ConfigScope.GLOBAL if not allow_cluster_name else ConfigScope.ALL + ) # Alias section uses different validation _check_sections(config, set(config_spec.keys()) | {"alias"}, filename) @@ -721,12 +708,16 @@ def _validate_alias( pass -def _load_file(filename: Path, allow_cluster_name: bool) -> Mapping[str, Any]: +def _load_file( + plugin_manager: PluginManager, filename: Path, allow_cluster_name: bool +) -> Mapping[str, Any]: try: config = toml.load(filename) except ValueError as exc: raise ConfigError(f"{filename}: {exc}") - _validate_user_config(config, filename, allow_cluster_name=allow_cluster_name) + _validate_user_config( + plugin_manager, config, filename, allow_cluster_name=allow_cluster_name + ) return config diff --git a/neuro-sdk/src/neuro_sdk/config_factory.py b/neuro-sdk/src/neuro_sdk/config_factory.py index 19b017752..518dd316c 100644 --- a/neuro-sdk/src/neuro_sdk/config_factory.py +++ b/neuro-sdk/src/neuro_sdk/config_factory.py @@ -5,7 +5,7 @@ import ssl import sys from pathlib import Path -from typing import Awaitable, Callable, List, Optional +from typing import Any, Awaitable, Callable, List, Mapping, Optional import aiohttp import certifi @@ -14,14 +14,20 @@ from neuro_sdk.login import AuthTokenClient from .client import Client -from .config import _ConfigData, _load, _load_recovery_data, _save +from .config import _ConfigData, _load, _load_recovery_data, _load_user_config, _save from .core import DEFAULT_TIMEOUT from .errors import ConfigError from .login import AuthNegotiator, HeadlessNegotiator, _AuthToken, logout_from_browser +from .plugins import PluginManager from .server_cfg import _ServerConfig, get_server_config from .tracing import _make_trace_config from .utils import _ContextManager +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points +else: + from importlib_metadata import entry_points + DEFAULT_CONFIG_PATH = "~/.neuro" CONFIG_ENV_NAME = "NEUROMATION_CONFIG" PASS_CONFIG_ENV_NAME = "NEURO_PASSED_CONFIG" @@ -68,6 +74,9 @@ def __init__( self._trace_configs += trace_configs self._trace_id = trace_id self._trace_sampled = trace_sampled + self._plugin_manager = PluginManager() + for entry_point in entry_points(group="neuro_api"): + entry_point.load()(self._plugin_manager) @property def path(self) -> Path: @@ -95,7 +104,11 @@ async def _get(self, *, timeout: aiohttp.ClientTimeout = DEFAULT_TIMEOUT) -> Cli session = await _make_session(timeout, self._trace_configs) try: client = Client._create( - session, self._path, self._trace_id, self._trace_sampled + session, + self._path, + self._trace_id, + self._trace_sampled, + self._plugin_manager, ) await client.config.check_server() except (asyncio.CancelledError, Exception): @@ -262,5 +275,8 @@ async def logout( # Directory Not Empty or Not A Directory pass + async def load_user_config(self) -> Mapping[str, Any]: + return _load_user_config(self._plugin_manager, self._path) + def _save(self, config: _ConfigData) -> None: _save(config, self._path, False) diff --git a/neuro-sdk/src/neuro_sdk/plugins.py b/neuro-sdk/src/neuro_sdk/plugins.py index 72cb1ce98..6b80c647b 100644 --- a/neuro-sdk/src/neuro_sdk/plugins.py +++ b/neuro-sdk/src/neuro_sdk/plugins.py @@ -1,11 +1,31 @@ +import enum import numbers -from typing import Any, Dict, List, Mapping, Tuple, Type, Union +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Mapping, Tuple, Type, Union from .errors import ConfigError +class ConfigScope(enum.Flag): + GLOBAL = enum.auto() + LOCAL = enum.auto() + ALL = GLOBAL | LOCAL + + +_ParamType = Union[ + Type[bool], + Type[numbers.Real], + Type[numbers.Integral], + Type[str], + Tuple[Type[List[Any]], Type[bool]], + Tuple[Type[List[Any]], Type[str]], + Tuple[Type[List[Any]], Type[numbers.Real]], + Tuple[Type[List[Any]], Type[numbers.Integral]], +] + + class ConfigBuilder: - _config_spec: Dict[str, Any] + _config_spec: Dict[str, Dict[str, Tuple[_ParamType, ConfigScope]]] def __init__(self) -> None: self._config_spec = dict() @@ -14,59 +34,105 @@ def _define_param( self, section: str, name: str, - type: Union[ - Type[bool], - Type[numbers.Real], - Type[numbers.Integral], - Type[str], - Tuple[Type[List[Any]], Type[bool]], - Tuple[Type[List[Any]], Type[str]], - Tuple[Type[List[Any]], Type[numbers.Real]], - Tuple[Type[List[Any]], Type[numbers.Integral]], - ], + type: _ParamType, + scope: ConfigScope, ) -> None: if section == "alias": raise ConfigError("Registering aliases is not supported yet.") if section in self._config_spec and name in self._config_spec[section]: raise ConfigError(f"Config parameter {section}.{name} already registered") self._config_spec.setdefault(section, dict()) - self._config_spec[section][name] = type + self._config_spec[section][name] = (type, scope) + + def _get_spec( + self, scope: ConfigScope = ConfigScope.ALL + ) -> Mapping[str, Mapping[str, _ParamType]]: + return { + section: {name: val[0] for name, val in body.items() if val[1] & scope} + for section, body in self._config_spec.items() + } + + def define_int( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, numbers.Integral, scope) + + def define_bool( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, bool, scope) + + def define_str( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, str, scope) + + def define_float( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, numbers.Real, scope) - def _get_spec(self) -> Mapping[str, Any]: - return self._config_spec + def define_int_list( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, (list, numbers.Integral), scope) - def define_int(self, section: str, name: str) -> None: - self._define_param(section, name, numbers.Integral) + def define_bool_list( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, (list, bool), scope) - def define_bool(self, section: str, name: str) -> None: - self._define_param(section, name, bool) + def define_str_list( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, (list, str), scope) - def define_str(self, section: str, name: str) -> None: - self._define_param(section, name, str) + def define_float_list( + self, section: str, name: str, *, scope: ConfigScope = ConfigScope.ALL + ) -> None: + self._define_param(section, name, (list, numbers.Real), scope) - def define_float(self, section: str, name: str) -> None: - self._define_param(section, name, numbers.Real) - def define_int_list(self, section: str, name: str) -> None: - self._define_param(section, name, (list, numbers.Integral)) +@dataclass(frozen=True) +class _VersionRecord: + package: str + update_text: Callable[[str, str], str] + exclusive: bool + delay: float - def define_bool_list(self, section: str, name: str) -> None: - self._define_param(section, name, (list, bool)) - def define_str_list(self, section: str, name: str) -> None: - self._define_param(section, name, (list, str)) +class VersionChecker: + def __init__(self) -> None: + self._records: Dict[str, _VersionRecord] = {} - def define_float_list(self, section: str, name: str) -> None: - self._define_param(section, name, (list, numbers.Real)) + def register( + self, + package: str, + update_text: Callable[[str, str], str], + *, + exclusive: bool = False, + delay: float = 0, + ) -> None: + record = _VersionRecord(package, update_text, exclusive, delay) + if exclusive and any(rec.exclusive for rec in self._records.values()): + pkgs = [rec.package for rec in self._records.values() if rec.exclusive] + raise ConfigError(f"Exclusive record for package {pkgs[0]} already exists") + self._records[package] = record class PluginManager: _config: ConfigBuilder + _version_checker: VersionChecker def __init__(self) -> None: self._config = ConfigBuilder() + self._version_checker = VersionChecker() @property def config(self) -> ConfigBuilder: return self._config + + @property + def version_checker(self) -> VersionChecker: + return self._version_checker diff --git a/neuro-sdk/tests/conftest.py b/neuro-sdk/tests/conftest.py index be8ce363f..e5fadc6bf 100644 --- a/neuro-sdk/tests/conftest.py +++ b/neuro-sdk/tests/conftest.py @@ -10,7 +10,7 @@ from jose import jwt from yarl import URL -from neuro_sdk import Client, Cluster, Preset, __version__ +from neuro_sdk import Client, Cluster, PluginManager, Preset, __version__ from neuro_sdk.config import _AuthConfig, _AuthToken, _ConfigData, _save from neuro_sdk.tracing import _make_trace_config @@ -102,6 +102,7 @@ def go( clusters: Optional[Dict[str, Cluster]] = None, token_url: Optional[URL] = None, admin_url: Optional[URL] = None, + plugin_manager: Optional[PluginManager] = None, ) -> Client: url = URL(url_str) if clusters is None: @@ -165,6 +166,8 @@ def go( real_auth_config = auth_config if admin_url is None: admin_url = URL(url) / ".." / ".." / "apis" / "admin" / "v1" + if plugin_manager is None: + plugin_manager = PluginManager() config = _ConfigData( auth_config=real_auth_config, auth_token=_AuthToken.create_non_expiring(token), @@ -177,6 +180,6 @@ def go( config_dir = tmp_path / ".neuro" _save(config, config_dir) session = aiohttp.ClientSession(trace_configs=[_make_trace_config()]) - return Client._create(session, config_dir, trace_id) + return Client._create(session, config_dir, trace_id, None, plugin_manager) return go diff --git a/neuro-sdk/tests/test_config.py b/neuro-sdk/tests/test_config.py index a397ba904..cbe86600a 100644 --- a/neuro-sdk/tests/test_config.py +++ b/neuro-sdk/tests/test_config.py @@ -10,7 +10,7 @@ from aiohttp import web from yarl import URL -from neuro_sdk import Client, Cluster, ConfigError, Preset +from neuro_sdk import Client, Cluster, ConfigError, ConfigScope, PluginManager, Preset from neuro_sdk.config import _check_sections, _merge_user_configs, _validate_user_config from neuro_sdk.login import _AuthToken @@ -19,6 +19,18 @@ _MakeClient = Callable[..., Client] +@pytest.fixture() +def plugin_manager() -> PluginManager: + manager = PluginManager() + manager.config.define_str("job", "ps-format") + manager.config.define_str("job", "top-format") + manager.config.define_str("job", "life-span") + manager.config.define_str("job", "cluster-name", scope=ConfigScope.LOCAL) + manager.config.define_str_list("storage", "cp-exclude") + manager.config.define_str_list("storage", "cp-exclude-from-files") + return manager + + class TestMergeUserConfigs: def test_empty_dicts(self) -> None: assert _merge_user_configs({}, {}) == {} @@ -51,35 +63,43 @@ def test_section_is_not_dict(self) -> None: ): _check_sections({"a": 1}, {"a"}, "file.cfg") - def test_invalid_alias_name(self) -> None: + def test_invalid_alias_name(self, plugin_manager: PluginManager) -> None: with pytest.raises(ConfigError, match="file.cfg: invalid alias name 0123"): - _validate_user_config({"alias": {"0123": "ls"}}, "file.cfg") + _validate_user_config(plugin_manager, {"alias": {"0123": "ls"}}, "file.cfg") - def test_invalid_alias_type(self) -> None: + def test_invalid_alias_type(self, plugin_manager: PluginManager) -> None: with pytest.raises(ConfigError, match="file.cfg: invalid alias command type"): - _validate_user_config({"alias": {"new-name": True}}, "file.cfg") + _validate_user_config( + plugin_manager, {"alias": {"new-name": True}}, "file.cfg" + ) - def test_extra_session_param(self) -> None: + def test_extra_session_param(self, plugin_manager: PluginManager) -> None: with pytest.raises( ConfigError, match="file.cfg: unknown parameters job.unknown-name" ): - _validate_user_config({"job": {"unknown-name": True}}, "file.cfg") + _validate_user_config( + plugin_manager, {"job": {"unknown-name": True}}, "file.cfg" + ) - def test_invalid_param_type(self) -> None: + def test_invalid_param_type(self, plugin_manager: PluginManager) -> None: with pytest.raises( ConfigError, match="file.cfg: invalid type for job.ps-format, str is expected", ): - _validate_user_config({"job": {"ps-format": True}}, "file.cfg") + _validate_user_config( + plugin_manager, {"job": {"ps-format": True}}, "file.cfg" + ) - def test_invalid_complex_type(self) -> None: + def test_invalid_complex_type(self, plugin_manager: PluginManager) -> None: with pytest.raises( ConfigError, match="file.cfg: invalid type for storage.cp-exclude, list is expected", ): - _validate_user_config({"storage": {"cp-exclude": "abc"}}, "file.cfg") + _validate_user_config( + plugin_manager, {"storage": {"cp-exclude": "abc"}}, "file.cfg" + ) - def test_invalid_complex_item_type(self) -> None: + def test_invalid_complex_item_type(self, plugin_manager: PluginManager) -> None: with pytest.raises( ConfigError, match=( @@ -87,11 +107,15 @@ def test_invalid_complex_item_type(self) -> None: "str is expected" ), ): - _validate_user_config({"storage": {"cp-exclude": [1, 2]}}, "file.cfg") + _validate_user_config( + plugin_manager, {"storage": {"cp-exclude": [1, 2]}}, "file.cfg" + ) - def test_not_allowed_cluster_name(self) -> None: + def test_not_allowed_cluster_name(self, plugin_manager: PluginManager) -> None: with pytest.raises(ConfigError, match=r"file.cfg: cluster name is not allowed"): - _validate_user_config({"job": {"cluster-name": "another"}}, "file.cfg") + _validate_user_config( + plugin_manager, {"job": {"cluster-name": "another"}}, "file.cfg" + ) async def test_get_user_config_empty(make_client: _MakeClient) -> None: @@ -170,8 +194,12 @@ async def test_get_cluster_name_from_local( make_client: _MakeClient, multiple_clusters_config: Dict[str, Cluster], ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "cluster-name", scope=ConfigScope.LOCAL) async with make_client( - "https://example.org", clusters=multiple_clusters_config + "https://example.org", + clusters=multiple_clusters_config, + plugin_manager=plugin_manager, ) as client: proj_dir = tmp_path / "project" local_dir = proj_dir / "folder" @@ -212,7 +240,11 @@ async def test_get_cluster_name_from_local_invalid_cluster( make_client: _MakeClient, multiple_clusters_config: Dict[str, Cluster], ) -> None: - async with make_client("https://example.org") as client: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "cluster-name", scope=ConfigScope.LOCAL) + async with make_client( + "https://example.org", plugin_manager=plugin_manager + ) as client: proj_dir = tmp_path / "project" local_dir = proj_dir / "folder" local_dir.mkdir(parents=True, exist_ok=True) @@ -510,8 +542,12 @@ async def test_switch_clusters_local( make_client: _MakeClient, multiple_clusters_config: Dict[str, Cluster], ) -> None: + plugin_manager = PluginManager() + plugin_manager.config.define_str("job", "cluster-name", scope=ConfigScope.LOCAL) async with make_client( - "https://example.org", clusters=multiple_clusters_config + "https://example.org", + clusters=multiple_clusters_config, + plugin_manager=plugin_manager, ) as client: proj_dir = tmp_path / "project" local_dir = proj_dir / "folder" diff --git a/neuro-cli/tests/unit/test_version_check.py b/neuro-sdk/tests/test_version_check.py similarity index 76% rename from neuro-cli/tests/unit/test_version_check.py rename to neuro-sdk/tests/test_version_check.py index db510e39a..e7669f896 100644 --- a/neuro-cli/tests/unit/test_version_check.py +++ b/neuro-sdk/tests/test_version_check.py @@ -2,7 +2,7 @@ import socket import ssl import time -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Tuple import aiohttp import dateutil.parser @@ -12,10 +12,7 @@ from aiohttp.abc import AbstractResolver from aiohttp.test_utils import unused_port -from neuro_sdk import Client - -from neuro_cli import version_utils -from neuro_cli.root import Root +from neuro_sdk import Client, PluginManager PYPI_JSON = { "info": { @@ -50,7 +47,7 @@ "platform": "", "project_url": "https://pypi.org/project/neuro-cli/", "project_urls": {"Homepage": "https://neu.ro/"}, - "release_url": "https://pypi.org/project/neuro-cli/0.2.1/", + "release_url": "https://pypi.org/project/neuro-cli/50.1.1/", "requires_dist": [ "aiohttp (>=3.0)", "python-jose (>=3.0.0)", @@ -63,7 +60,7 @@ ], "requires_python": ">=3.6.0", "summary": "Neuro Platform API client", - "version": "0.2.1", + "version": "50.1.1", }, "last_serial": 4757285, "releases": { @@ -83,10 +80,10 @@ "requires_python": ">=3.6.0", "size": 47043, "upload_time": "2019-01-28T20:01:21", - "url": "https://files.pytho...ation-0.2.1-py3-none-any.whl", + "url": "https://files.pytho...ation-50.1.1-py3-none-any.whl", } ], - "0.2.1": [ + "50.1.1": [ { "comment_text": "", "digests": { @@ -94,7 +91,7 @@ "sha256": "fd50b1f904c4...af6213c363ec5a83f3168aae1b8", }, "downloads": -1, - "filename": "neuro-cli-0.2.1-py3-none-any.whl", + "filename": "neuro-cli-50.1.1-py3-none-any.whl", "has_sig": False, "md5_digest": "8dd303ee04215ff7f5c2e7f03a6409da", "packagetype": "bdist_wheel", @@ -102,7 +99,7 @@ "requires_python": ">=3.6.0", "size": 48633, "upload_time": "2019-01-29T23:45:22", - "url": "https://files.pytho...ation-0.2.1-py3-none-any.whl", + "url": "https://files.pytho...ation-50.1.1-py3-none-any.whl", }, { "comment_text": "", @@ -111,7 +108,7 @@ "sha256": "046832c04d4e7...38f6514d0e5b9acc4939", }, "downloads": -1, - "filename": "neuro-cli-0.2.1.tar.gz", + "filename": "neuro-cli-50.1.1.tar.gz", "has_sig": False, "md5_digest": "af8fea5f3df6f7f81e9c6cbc6dd7c1e8", "packagetype": "sdist", @@ -119,7 +116,7 @@ "requires_python": None, "size": 156721, "upload_time": "2019-01-30T00:02:23", - "url": "https://files.pytho...ation-0.2.1.tar.gz", + "url": "https://files.pytho...ation-50.1.1.tar.gz", }, ], }, @@ -131,7 +128,7 @@ "sha256": "fd50b1f90c...c5a83f3168aae1b8", }, "downloads": -1, - "filename": "neuro-cli-0.2.1-py3-none-any.whl", + "filename": "neuro-cli-50.1.1-py3-none-any.whl", "has_sig": False, "md5_digest": "8dd303ee04215ff7f5c2e7f03a6409da", "packagetype": "bdist_wheel", @@ -139,7 +136,7 @@ "requires_python": ">=3.6.0", "size": 48633, "upload_time": "2019-01-29T23:45:22", - "url": "https://files.pytho...ation-0.2.1-py3-none-any.whl", + "url": "https://files.pytho...ation-50.1.1-py3-none-any.whl", } ], } @@ -230,17 +227,32 @@ async def fake_pypi( await fake_pypi.stop() +NEURO_CLI_UPGRADE = """\ +You are using Neuro Platform Client {old_ver}, however {new_ver} is available. +You should consider upgrading via the following command: + python -m pip install --upgrade neuro-cli +""" + + +def get_neuro_cli_txt(old: str, new: str) -> str: + return NEURO_CLI_UPGRADE.format(old_ver=old, new_ver=new) + + @pytest.fixture() async def client( - fake_pypi: Tuple[FakePyPI, Dict[str, int]], root: Root + fake_pypi: Tuple[FakePyPI, Dict[str, int]], + make_client: Callable[..., Client], ) -> AsyncIterator[Client]: resolver = FakeResolver(fake_pypi[1]) connector = aiohttp.TCPConnector(resolver=resolver, ssl=False, keepalive_timeout=0) - old_session = root.client._session - async with aiohttp.ClientSession(connector=connector) as session: - root.client._session = session - yield root.client - root.client._session = old_session + plugin_manager = PluginManager() + plugin_manager.version_checker.register("neuro-cli", get_neuro_cli_txt) + client = make_client("http://example.com", plugin_manager=plugin_manager) + client._session = aiohttp.ClientSession(connector=connector) + client._core._session = client._session + yield client + await client.close() + await asyncio.sleep(0.5) # can be removed for aiohttp 4.0 @pytest.fixture @@ -248,52 +260,43 @@ def pypi_server(fake_pypi: Tuple[FakePyPI, Dict[str, int]]) -> FakePyPI: return fake_pypi[0] -async def test__fetch_pypi(pypi_server: FakePyPI, client: Client) -> None: +async def test_update(pypi_server: FakePyPI, client: Client) -> None: pypi_server.response = (200, PYPI_JSON) t0 = time.time() - record = await version_utils._fetch_package(client._session, "neuro-cli") - assert record is not None - assert record["version"] == "0.2.1" + await client.version_checker.update() + assert len(client.version_checker._records) == 1 + record = client.version_checker._records["neuro-cli"] + assert record["package"] == "neuro-cli" + assert record["version"] == "50.1.1" assert ( record["uploaded"] == dateutil.parser.parse("2019-01-30T00:02:23").timestamp() ) assert t0 <= record["checked"] <= time.time() + with client.config._open_db() as db: + ret = list(db.execute("SELECT package, version FROM pypi")) + assert len(ret) == 1 + assert list(ret[0]) == ["neuro-cli", "50.1.1"] -async def test__fetch_pypi_no_releases(pypi_server: FakePyPI, client: Client) -> None: + +async def test_update_no_releases(pypi_server: FakePyPI, client: Client) -> None: pypi_server.response = (200, {}) - record = await version_utils._fetch_package(client._session, "neuro-cli") - assert record is None + await client.version_checker.update() + assert not client.version_checker._records -async def test__fetch_pypi_non_200(pypi_server: FakePyPI, client: Client) -> None: +async def test_update_non_200(pypi_server: FakePyPI, client: Client) -> None: pypi_server.response = (403, {"Status": "Forbidden"}) - record = await version_utils._fetch_package(client._session, "neuro-cli") - assert record is None - + await client.version_checker.update() + assert not client.version_checker._records -async def test_run_version_checker(pypi_server: FakePyPI, client: Client) -> None: - pypi_server.response = (200, PYPI_JSON) - - await version_utils.run_version_checker(client, False) - with client.config._open_db() as db: - ret = list(db.execute("SELECT package, version FROM pypi")) - assert len(ret) == 1 - assert list(ret[0]) == ["neuro-cli", "0.2.1"] - -async def test_run_version_checker_disabled( - pypi_server: FakePyPI, client: Client -) -> None: +async def test_get_outdated(pypi_server: FakePyPI, client: Client) -> None: pypi_server.response = (200, PYPI_JSON) - with client.config._open_db() as db: - version_utils._ensure_schema(db) - - await version_utils.run_version_checker(client, True) - with client.config._open_db() as db: - ret = list(db.execute("SELECT package, version FROM pypi")) - assert len(ret) == 0 + await client.version_checker.update() + outdated = await client.version_checker.get_outdated() + assert "neuro-cli" in outdated.keys()