diff --git a/src/snowflake/cli/_plugins/snowpark/commands.py b/src/snowflake/cli/_plugins/snowpark/commands.py index 3b66657e8..d0e471376 100644 --- a/src/snowflake/cli/_plugins/snowpark/commands.py +++ b/src/snowflake/cli/_plugins/snowpark/commands.py @@ -35,7 +35,6 @@ from snowflake.cli._plugins.object.manager import ObjectManager from snowflake.cli._plugins.snowpark import package_utils from snowflake.cli._plugins.snowpark.common import ( - FunctionOrProcedure, UdfSprocIdentifier, check_if_replace_is_required, ) @@ -54,7 +53,10 @@ ) from snowflake.cli._plugins.snowpark.zipper import zip_dir from snowflake.cli._plugins.stage.manager import StageManager -from snowflake.cli.api.cli_global_context import get_cli_context +from snowflake.cli.api.cli_global_context import ( + _CliGlobalContextAccess, + get_cli_context, +) from snowflake.cli.api.commands.decorators import ( with_project_definition, ) @@ -70,7 +72,11 @@ DEPLOYMENT_STAGE, ObjectType, ) -from snowflake.cli.api.exceptions import SecretsWithoutExternalAccessIntegrationError +from snowflake.cli.api.entities.snowpark_entity import SnowparkEntity +from snowflake.cli.api.exceptions import ( + NoProjectDefinitionError, + SecretsWithoutExternalAccessIntegrationError, +) from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.output.types import ( CollectionResult, @@ -78,7 +84,10 @@ MessageResult, SingleQueryResult, ) -from snowflake.cli.api.project.project_verification import assert_project_type +from snowflake.cli.api.project.schemas.project_definition import ( + ProjectDefinition, + ProjectDefinitionV2, +) from snowflake.cli.api.project.schemas.snowpark.callable import ( FunctionSchema, ProcedureSchema, @@ -121,18 +130,15 @@ def deploy( By default, if any of the objects exist already the commands will fail unless `--replace` flag is provided. All deployed objects use the same artifact which is deployed only once. """ + cli_context, pd = _get_v2_context_and_project_definition() - assert_project_type("snowpark") - - cli_context = get_cli_context() - snowpark = cli_context.project_definition.snowpark paths = SnowparkPackagePaths.for_snowpark_project( project_root=SecurePath(cli_context.project_root), - snowpark_project_definition=snowpark, + project_definition=pd, ) - procedures = snowpark.procedures - functions = snowpark.functions + procedures = pd.get_entities_by_type("procedure") + functions = pd.get_entities_by_type("function") if not procedures and not functions: raise ClickException( @@ -164,29 +170,33 @@ def deploy( raise ClickException(msg) # Create stage - stage_name = snowpark.stage_name - stage_manager = StageManager() - stage_name = FQN.from_string(stage_name).using_context() - stage_manager.create(fqn=stage_name, comment="deployments managed by Snowflake CLI") - snowflake_dependencies = _read_snowflake_requrements_file( paths.snowflake_requirements_file ) + stage_names = { + entity.stage for entity in [*functions.values(), *procedures.values()] + } + stage_manager = StageManager() - artifact_stage_directory = get_app_stage_path(stage_name, snowpark.project_name) - artifact_stage_target = ( - f"{artifact_stage_directory}/{paths.artifact_file.path.name}" - ) + # TODO: Raise error if stage name is not provided - stage_manager.put( - local_path=paths.artifact_file.path, - stage_path=artifact_stage_directory, - overwrite=True, - ) + for stage in stage_names: + stage = FQN.from_string(stage).using_context() + stage_manager.create(fqn=stage, comment="deployments managed by Snowflake CLI") + artifact_stage_directory = get_app_stage_path(stage, pd.defaults.project_name) + artifact_stage_target = ( + f"{artifact_stage_directory}/{paths.artifact_file.path.name}" + ) + + stage_manager.put( + local_path=paths.artifact_file.path, + stage_path=artifact_stage_directory, + overwrite=True, + ) deploy_status = [] # Procedures - for procedure in procedures: + for procedure in procedures.values(): operation_result = _deploy_single_object( manager=pm, object_type=ObjectType.PROCEDURE, @@ -198,7 +208,7 @@ def deploy( deploy_status.append(operation_result) # Functions - for function in functions: + for function in functions.values(): operation_result = _deploy_single_object( manager=fm, object_type=ObjectType.FUNCTION, @@ -213,9 +223,9 @@ def deploy( def _assert_object_definitions_are_correct( - object_type, object_definitions: List[FunctionOrProcedure] + object_type, object_definitions: Dict[str, SnowparkEntity] ): - for definition in object_definitions: + for name, definition in object_definitions.items(): database = definition.database schema = definition.schema_name name = definition.name @@ -232,11 +242,11 @@ def _assert_object_definitions_are_correct( def _find_existing_objects( object_type: ObjectType, - objects: List[FunctionOrProcedure], + objects: Dict[str, SnowparkEntity], om: ObjectManager, ): existing_objects = {} - for object_definition in objects: + for object_name, object_definition in objects.items(): identifier = UdfSprocIdentifier.from_definition( object_definition ).identifier_with_arg_types @@ -253,8 +263,8 @@ def _find_existing_objects( def _check_if_all_defined_integrations_exists( om: ObjectManager, - functions: List[FunctionSchema], - procedures: List[ProcedureSchema], + functions: Dict[str, FunctionSchema], + procedures: Dict[str, ProcedureSchema], ): existing_integrations = { i["name"].lower() @@ -262,7 +272,7 @@ def _check_if_all_defined_integrations_exists( if i["type"] == "EXTERNAL_ACCESS" } declared_integration: Set[str] = set() - for object_definition in [*functions, *procedures]: + for object_definition in [*functions.values(), *procedures.values()]: external_access_integrations = { s.lower() for s in object_definition.external_access_integrations } @@ -280,7 +290,7 @@ def _check_if_all_defined_integrations_exists( ) -def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str: +def get_app_stage_path(stage_name: Optional[str | FQN], project_name: str) -> str: artifact_stage_directory = f"@{(stage_name or DEPLOYMENT_STAGE)}/{project_name}" return artifact_stage_directory @@ -288,7 +298,7 @@ def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str: def _deploy_single_object( manager: FunctionManager | ProcedureManager, object_type: ObjectType, - object_definition: FunctionOrProcedure, + object_definition: SnowparkEntity, existing_objects: Dict[str, Dict], snowflake_dependencies: List[str], stage_artifact_path: str, @@ -374,16 +384,16 @@ def build( ) -> CommandResult: """ Builds the Snowpark project as a `.zip` archive that can be used by `deploy` command. - The archive is built using only the `src` directory specified in the project file. + The archive is built using only the `artifacts` directory specified in the project file. """ + cli_context, pd = _get_v2_context_and_project_definition() - assert_project_type("snowpark") - cli_context = get_cli_context() snowpark_paths = SnowparkPackagePaths.for_snowpark_project( project_root=SecurePath(cli_context.project_root), - snowpark_project_definition=cli_context.project_definition.snowpark, + project_definition=pd, ) - log.info("Building package using sources from: %s", snowpark_paths.source.path) + log.info("Building package using sources from:") + log.info(",".join(str(s) for s in snowpark_paths.sources)) anaconda_packages_manager = AnacondaPackagesManager() @@ -424,7 +434,7 @@ def build( ) zip_dir( - source=snowpark_paths.source.path, + source=snowpark_paths.sources_paths, dest_zip=snowpark_paths.artifact_file.path, ) if any(packages_dir.iterdir()): @@ -510,3 +520,52 @@ def describe( ): """Provides description of a procedure or function.""" object_describe(object_type=object_type.value, object_name=identifier, **options) + + +def _migrate_v1_snowpark_to_v2(pd: ProjectDefinition): + if not pd.snowpark: + raise NoProjectDefinitionError( + project_type="snowpark", project_file=get_cli_context().project_root + ) + + data: dict = { + "definition_version": "2", + "defaults": { + "stage": pd.snowpark.stage_name, + "project_name": pd.snowpark.project_name, + }, + "entities": {}, + } + + for entity in [*pd.snowpark.procedures, *pd.snowpark.functions]: + v2_entity = { + "type": "function" if isinstance(entity, FunctionSchema) else "procedure", + "stage": pd.snowpark.stage_name, + "artifacts": pd.snowpark.src, + "handler": entity.handler, + "returns": entity.returns, + "signature": entity.signature, + "runtime": entity.runtime, + "external_access_integrations": entity.external_access_integrations, + "secrets": entity.secrets, + "imports": entity.imports, + "name": entity.name, + "database": entity.database, + "schema": entity.schema_name, + } + if isinstance(entity, ProcedureSchema): + v2_entity["execute_as_caller"] = entity.execute_as_caller + + data["entities"][entity.name] = v2_entity + + return ProjectDefinitionV2(**data) + + +def _get_v2_context_and_project_definition() -> Tuple[ + _CliGlobalContextAccess, ProjectDefinitionV2 +]: + cli_context = get_cli_context() + pd = cli_context.project_definition + if not pd.meets_version_requirement("2"): + pd = _migrate_v1_snowpark_to_v2(pd) + return cli_context, pd diff --git a/src/snowflake/cli/_plugins/snowpark/common.py b/src/snowflake/cli/_plugins/snowpark/common.py index 3e45f5694..e5116119c 100644 --- a/src/snowflake/cli/_plugins/snowpark/common.py +++ b/src/snowflake/cli/_plugins/snowpark/common.py @@ -15,23 +15,19 @@ from __future__ import annotations import re -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set from snowflake.cli._plugins.snowpark.models import Requirement from snowflake.cli._plugins.snowpark.package_utils import ( generate_deploy_stage_name, ) from snowflake.cli.api.constants import ObjectType +from snowflake.cli.api.entities.snowpark_entity import SnowparkEntity from snowflake.cli.api.identifiers import FQN -from snowflake.cli.api.project.schemas.snowpark.callable import ( - FunctionSchema, - ProcedureSchema, -) from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector.cursor import SnowflakeCursor DEFAULT_RUNTIME = "3.10" -FunctionOrProcedure = Union[FunctionSchema, ProcedureSchema] def check_if_replace_is_required( @@ -271,7 +267,7 @@ def identifier_for_sql(self): return self._identifier_from_signature(self._full_signature(), for_sql=True) @classmethod - def from_definition(cls, udf_sproc: FunctionOrProcedure): + def from_definition(cls, udf_sproc: SnowparkEntity): names = [] types = [] defaults = [] diff --git a/src/snowflake/cli/_plugins/snowpark/package_utils.py b/src/snowflake/cli/_plugins/snowpark/package_utils.py index 2f4b4c9f5..70c9d261e 100644 --- a/src/snowflake/cli/_plugins/snowpark/package_utils.py +++ b/src/snowflake/cli/_plugins/snowpark/package_utils.py @@ -61,7 +61,7 @@ def parse_requirements( ).splitlines(): line = re.sub(r"\s*#.*", "", line).strip() if line: - reqs.append(Requirement.parse(line)) + reqs.append(Requirement.parse_line(line)) return reqs diff --git a/src/snowflake/cli/_plugins/snowpark/snowpark_package_paths.py b/src/snowflake/cli/_plugins/snowpark/snowpark_package_paths.py index 87a9a863f..321db4790 100644 --- a/src/snowflake/cli/_plugins/snowpark/snowpark_package_paths.py +++ b/src/snowflake/cli/_plugins/snowpark/snowpark_package_paths.py @@ -13,8 +13,10 @@ # limitations under the License. from dataclasses import dataclass +from pathlib import Path +from typing import List -from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark +from snowflake.cli.api.project.schemas.project_definition import DefinitionV20 from snowflake.cli.api.secure_path import SecurePath _DEFINED_REQUIREMENTS = "requirements.txt" @@ -23,24 +25,31 @@ @dataclass class SnowparkPackagePaths: - source: SecurePath + sources: List[SecurePath] artifact_file: SecurePath defined_requirements_file: SecurePath = SecurePath(_DEFINED_REQUIREMENTS) snowflake_requirements_file: SecurePath = SecurePath(_REQUIREMENTS_SNOWFLAKE) @classmethod def for_snowpark_project( - cls, project_root: SecurePath, snowpark_project_definition: Snowpark + cls, project_root: SecurePath, project_definition: DefinitionV20 ) -> "SnowparkPackagePaths": - defined_source_path = SecurePath(snowpark_project_definition.src) + sources = set() + entities = project_definition.get_entities_by_type( + "function" + ) | project_definition.get_entities_by_type("procedure") + for name, entity in entities.items(): + sources.add(entity.artifacts) + return cls( - source=cls._get_snowpark_project_source_absolute_path( - project_root=project_root, - defined_source_path=defined_source_path, - ), + sources=[ + cls._get_snowpark_project_source_absolute_path( + project_root, SecurePath(source) + ) + for source in sources + ], artifact_file=cls._get_snowpark_project_artifact_absolute_path( project_root=project_root, - defined_source_path=defined_source_path, ), defined_requirements_file=project_root / _DEFINED_REQUIREMENTS, snowflake_requirements_file=project_root / _REQUIREMENTS_SNOWFLAKE, @@ -56,10 +65,12 @@ def _get_snowpark_project_source_absolute_path( @classmethod def _get_snowpark_project_artifact_absolute_path( - cls, project_root: SecurePath, defined_source_path: SecurePath + cls, project_root: SecurePath ) -> SecurePath: - source_path = cls._get_snowpark_project_source_absolute_path( - project_root=project_root, defined_source_path=defined_source_path - ) - artifact_file = project_root / (source_path.path.name + ".zip") + + artifact_file = project_root / "app.zip" return artifact_file + + @property + def sources_paths(self) -> List[Path]: + return [source.path for source in self.sources] diff --git a/src/snowflake/cli/_plugins/snowpark/zipper.py b/src/snowflake/cli/_plugins/snowpark/zipper.py index 003a88337..abcb45722 100644 --- a/src/snowflake/cli/_plugins/snowpark/zipper.py +++ b/src/snowflake/cli/_plugins/snowpark/zipper.py @@ -17,7 +17,7 @@ import fnmatch import logging from pathlib import Path -from typing import Iterator, Literal +from typing import Dict, List, Literal from zipfile import ZIP_DEFLATED, ZipFile log = logging.getLogger(__name__) @@ -59,16 +59,24 @@ def add_file_to_existing_zip(zip_file: str, file: str): def zip_dir( - source: Path, dest_zip: Path, mode: Literal["r", "w", "x", "a"] = "w" + source: Path | List[Path], + dest_zip: Path, + mode: Literal["r", "w", "x", "a"] = "w", ) -> None: - files_to_pack: Iterator[Path] = filter( - _to_be_zipped, map(lambda f: f.absolute(), source.glob("**/*")) - ) + + if isinstance(source, Path): + source = [source] + + files_to_pack: Dict[Path, List[Path]] = { + src: list(filter(_to_be_zipped, (f.absolute() for f in src.glob("**/*")))) + for src in source + } with ZipFile(dest_zip, mode, ZIP_DEFLATED, allowZip64=True) as package_zip: - for file in files_to_pack: - log.debug("Adding %s to %s", file, dest_zip) - package_zip.write(file, arcname=file.relative_to(source.absolute())) + for src, files in files_to_pack.items(): + for file in files: + log.debug("Adding %s to %s", file, dest_zip) + package_zip.write(file, arcname=file.relative_to(src)) def _to_be_zipped(file: Path) -> bool: diff --git a/src/snowflake/cli/api/entities/snowpark_entity.py b/src/snowflake/cli/api/entities/snowpark_entity.py new file mode 100644 index 000000000..552f0746d --- /dev/null +++ b/src/snowflake/cli/api/entities/snowpark_entity.py @@ -0,0 +1,21 @@ +from snowflake.cli.api.entities.common import EntityBase + + +class SnowparkEntity(EntityBase): + pass + + +class FunctionEntity(SnowparkEntity): + """ + A single UDF + """ + + pass + + +class ProcedureEntity(SnowparkEntity): + """ + A stored procedure + """ + + pass diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 634e5f10e..6d3362831 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -54,6 +54,11 @@ class DefaultsField(UpdatableModel): default=None, ) + project_name: Optional[str] = Field( + title="Name of the project.", + default="my_project", + ) + class EntityModelBase(ABC, UpdatableModel): @classmethod diff --git a/src/snowflake/cli/api/project/schemas/entities/entities.py b/src/snowflake/cli/api/project/schemas/entities/entities.py index 07f0b8a89..1bc42654b 100644 --- a/src/snowflake/cli/api/project/schemas/entities/entities.py +++ b/src/snowflake/cli/api/project/schemas/entities/entities.py @@ -20,6 +20,7 @@ from snowflake.cli.api.entities.application_package_entity import ( ApplicationPackageEntity, ) +from snowflake.cli.api.entities.snowpark_entity import FunctionEntity, ProcedureEntity from snowflake.cli.api.entities.streamlit_entity import StreamlitEntity from snowflake.cli.api.project.schemas.entities.application_entity_model import ( ApplicationEntityModel, @@ -27,13 +28,27 @@ from snowflake.cli.api.project.schemas.entities.application_package_entity_model import ( ApplicationPackageEntityModel, ) +from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( + FunctionEntityModel, + ProcedureEntityModel, +) from snowflake.cli.api.project.schemas.entities.streamlit_entity_model import ( StreamlitEntityModel, ) -Entity = Union[ApplicationEntity, ApplicationPackageEntity, StreamlitEntity] +Entity = Union[ + ApplicationEntity, + ApplicationPackageEntity, + StreamlitEntity, + ProcedureEntity, + FunctionEntity, +] EntityModel = Union[ - ApplicationEntityModel, ApplicationPackageEntityModel, StreamlitEntityModel + ApplicationEntityModel, + ApplicationPackageEntityModel, + StreamlitEntityModel, + FunctionEntityModel, + ProcedureEntityModel, ] ALL_ENTITIES: List[Entity] = [*get_args(Entity)] @@ -44,4 +59,6 @@ ApplicationEntityModel: ApplicationEntity, ApplicationPackageEntityModel: ApplicationPackageEntity, StreamlitEntityModel: StreamlitEntity, + FunctionEntityModel: FunctionEntity, + ProcedureEntityModel: ProcedureEntity, } diff --git a/src/snowflake/cli/api/project/schemas/entities/snowpark_entity.py b/src/snowflake/cli/api/project/schemas/entities/snowpark_entity.py new file mode 100644 index 000000000..978861a87 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/entities/snowpark_entity.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Dict, List, Literal, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.entities.common import EntityModelBase +from snowflake.cli.api.project.schemas.identifier_model import ObjectIdentifierModel +from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.api.project.schemas.updatable_model import DiscriminatorField + + +class SnowparkEntityModel(EntityModelBase): + handler: str = Field( + title="Function’s or procedure’s implementation of the object inside source module", + examples=["functions.hello_function"], + ) + returns: str = Field( + title="Type of the result" + ) # TODO: again, consider Literal/Enum + signature: Union[str, List[Argument]] = Field( + title="The signature parameter describes consecutive arguments passed to the object" + ) + runtime: Optional[Union[str, float]] = Field( + title="Python version to use when executing ", default=None + ) + external_access_integrations: Optional[List[str]] = Field( + title="Names of external access integrations needed for this procedure’s handler code to access external networks", + default=[], + ) + secrets: Optional[Dict[str, str]] = Field( + title="Assigns the names of secrets to variables so that you can use the variables to reference the secrets", + default={}, + ) + imports: Optional[List[str]] = Field( + title="Stage and path to previously uploaded files you want to import", + default=[], + ) + stage: str = Field(title="Stage in which artifacts will be stored") + artifacts: str = Field(title="Folder where your code should be located") + + @field_validator("runtime") + @classmethod + def convert_runtime(cls, runtime_input: Union[str, float]) -> str: + if isinstance(runtime_input, float): + return str(runtime_input) + return runtime_input + + +class ProcedureEntityModel(SnowparkEntityModel, ObjectIdentifierModel("procedure")): # type: ignore + type: Literal["procedure"] = DiscriminatorField() # noqa: A003 + execute_as_caller: Optional[bool] = Field( + title="Determine whether the procedure is executed with the privileges of " + "the owner (you) or with the privileges of the caller", + default=False, + ) + + +class FunctionEntityModel(SnowparkEntityModel, ObjectIdentifierModel("function")): # type: ignore + type: Literal["function"] = DiscriminatorField() # noqa: A003 diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index c72f96e90..ce74e3873 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -42,6 +42,8 @@ from snowflake.cli.api.utils.types import Context from typing_extensions import Annotated +AnnotatedEntity = Annotated[EntityModel, Field(discriminator="type")] + @dataclass class ProjectProperties: @@ -111,9 +113,7 @@ class DefinitionV11(DefinitionV10): class DefinitionV20(_ProjectDefinitionBase): - entities: Dict[str, Annotated[EntityModel, Field(discriminator="type")]] = Field( - title="Entity definitions." - ) + entities: Dict[str, AnnotatedEntity] = Field(title="Entity definitions.") @model_validator(mode="before") @classmethod @@ -148,18 +148,28 @@ def validate_entities_identifiers( @field_validator("entities", mode="after") @classmethod def validate_entities( - cls, entities: Dict[str, EntityModel] - ) -> Dict[str, EntityModel]: + cls, entities: Dict[str, AnnotatedEntity] + ) -> Dict[str, AnnotatedEntity]: for key, entity in entities.items(): # TODO Automatically detect TargetFields to validate - if entity.type == ApplicationEntityModel.get_type(): - if isinstance(entity.from_, TargetField): - target_key = entity.from_.target - target_object = entity.from_ - target_type = target_object.get_type() - cls._validate_target_field(target_key, target_type, entities) + if isinstance(entity, list): + for e in entity: + cls._validate_single_entity(e, entities) + else: + cls._validate_single_entity(entity, entities) return entities + @classmethod + def _validate_single_entity( + cls, entity: EntityModel, entities: Dict[str, AnnotatedEntity] + ): + if entity.type == ApplicationEntityModel.get_type(): + if isinstance(entity.from_, TargetField): + target_key = entity.from_.target + target_object = entity.from_ + target_type = target_object.get_type() + cls._validate_target_field(target_key, target_type, entities) + @classmethod def _validate_target_field( cls, target_key: str, target_type: EntityModel, entities: Dict[str, EntityModel] diff --git a/tests/__snapshots__/test_connection.ambr b/tests/__snapshots__/test_connection.ambr index bb35fbbd7..e68847957 100644 --- a/tests/__snapshots__/test_connection.ambr +++ b/tests/__snapshots__/test_connection.ambr @@ -1,76 +1,4 @@ # serializer version: 1 -# name: test_connection_can_be_added_with_existing_paths_in_arguments[-k] - ''' - Snowflake password [optional]: - Role for the connection [optional]: - Warehouse for the connection [optional]: - Database for the connection [optional]: - Schema for the connection [optional]: - Connection host [optional]: - Snowflake region [optional]: - Authentication method [optional]: - Path to token file [optional]: - Wrote new connection conn1 to /Users/jsikorski/.snowflake/config.toml - - ''' -# --- -# name: test_connection_can_be_added_with_existing_paths_in_arguments[-t] - ''' - Snowflake password [optional]: - Role for the connection [optional]: - Warehouse for the connection [optional]: - Database for the connection [optional]: - Schema for the connection [optional]: - Connection host [optional]: - Snowflake region [optional]: - Authentication method [optional]: - Path to private key file [optional]: - Wrote new connection conn1 to /Users/jsikorski/.snowflake/config.toml - - ''' -# --- -# name: test_connection_can_be_added_with_existing_paths_in_prompt[10] - ''' - [connections.connName] - account = "accName" - user = "userName" - password = "password" - token_file_path = "/var/folders/k8/3sdqh3nn4gg7lpr5fz0fjlqw0000gn/T/tmpjbd8o_i2" - - ''' -# --- -# name: test_connection_can_be_added_with_existing_paths_in_prompt[9] - ''' - [connections.connName] - account = "accName" - user = "userName" - password = "password" - private_key_path = "/var/folders/k8/3sdqh3nn4gg7lpr5fz0fjlqw0000gn/T/tmp0rnw_ay8" - - ''' -# --- -# name: test_file_paths_have_to_exist_when_given_in_arguments[-k] - ''' - +- Error ----------------------------------------------------------------------+ - | Path ~/path/to/file does not exist. | - +------------------------------------------------------------------------------+ - - ''' -# --- -# name: test_file_paths_have_to_exist_when_given_in_arguments[-t] - ''' - +- Error ----------------------------------------------------------------------+ - | Path ~/path/to/file does not exist. | - +------------------------------------------------------------------------------+ - - ''' -# --- -# name: test_file_paths_have_to_exist_when_given_in_prompt[10] - '' -# --- -# name: test_file_paths_have_to_exist_when_given_in_prompt[9] - '' -# --- # name: test_if_whitespaces_are_stripped_from_connection_name ''' [connections.whitespaceTest] diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 1dd0a2778..272e49b9b 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -2904,8 +2904,8 @@ Usage: default snowpark build [OPTIONS] Builds the Snowpark project as a `.zip` archive that can be used by `deploy` - command. The archive is built using only the `src` directory specified in the - project file. + command. The archive is built using only the `artifacts` directory specified + in the project file. +- Options --------------------------------------------------------------------+ | --ignore-anaconda Does not lookup packages on | @@ -3646,8 +3646,8 @@ +------------------------------------------------------------------------------+ +- Commands -------------------------------------------------------------------+ | build Builds the Snowpark project as a `.zip` archive that can be used | - | by `deploy` command. The archive is built using only the `src` | - | directory specified in the project file. | + | by `deploy` command. The archive is built using only the | + | `artifacts` directory specified in the project file. | | deploy Deploys procedures and functions defined in project. Deploying | | the project alters all objects defined in it. By default, if any | | of the objects exist already the commands will fail unless | @@ -8083,8 +8083,8 @@ +------------------------------------------------------------------------------+ +- Commands -------------------------------------------------------------------+ | build Builds the Snowpark project as a `.zip` archive that can be used | - | by `deploy` command. The archive is built using only the `src` | - | directory specified in the project file. | + | by `deploy` command. The archive is built using only the | + | `artifacts` directory specified in the project file. | | deploy Deploys procedures and functions defined in project. Deploying | | the project alters all objects defined in it. By default, if any | | of the objects exist already the commands will fail unless | diff --git a/tests/project/test_project_definition_v2.py b/tests/project/test_project_definition_v2.py index 133cf2897..33312bc62 100644 --- a/tests/project/test_project_definition_v2.py +++ b/tests/project/test_project_definition_v2.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - import pytest +from snowflake.cli._plugins.snowpark.commands import _migrate_v1_snowpark_to_v2 +from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.schemas.entities.entities import ( ALL_ENTITIES, @@ -21,9 +22,13 @@ v2_entity_model_to_entity_map, v2_entity_model_types_map, ) +from snowflake.cli.api.project.schemas.entities.snowpark_entity import ( + SnowparkEntityModel, +) from snowflake.cli.api.project.schemas.project_definition import ( DefinitionV20, ) +from snowflake.cli.api.project.schemas.snowpark.callable import _CallableBase from tests.testing_utils.mock_config import mock_config_key @@ -162,6 +167,68 @@ }, None, ], + # Snowpark fields + [ + { + "defaults": {"stage": "dev"}, + "entities": { + "function1": { + "type": "function", + "name": "name", + "handler": "app.hello", + "returns": "string", + "signature": [{"name": "name", "type": "string"}], + "runtime": "3.10", + "artifacts": "src", + } + }, + }, + None, + ], + [ + { + "defaults": {"stage": "dev", "project_name": "my_project"}, + "entities": { + "procedure1": { + "type": "procedure", + "name": "name", + "handler": "app.hello", + "returns": "string", + "signature": [{"name": "name", "type": "string"}], + "runtime": "3.10", + "artifacts": "src", + "execute_as_caller": True, + } + }, + }, + None, + ], + [ + { + "defaults": {"stage": "dev", "project_name": "my_project"}, + "entities": { + "procedure1": { + "type": "procedure", + "handler": "app.hello", + "returns": "string", + "signature": [{"name": "name", "type": "string"}], + "runtime": "3.10", + "artifacts": "src", + "execute_as_caller": True, + } + }, + }, + [ + "Your project definition is missing the following field: 'entities.procedure1.procedure.name'", + ], + ], + [ + {"entities": {"function1": {"type": "function", "handler": "app.hello"}}}, + [ + "Your project definition is missing the following field: 'entities.function1.function.returns'", + "Your project definition is missing the following field: 'entities.function1.function.signature'", + ], + ], ], ) def test_project_definition_v2_schema(definition_input, expected_error): @@ -276,3 +343,53 @@ def test_entity_model_to_entity_map(): entity_models.remove(entity_model_class) assert len(entities) == 0 assert len(entity_models) == 0 + + +@pytest.mark.parametrize( + "project_name", + [ + "snowpark_functions", + "snowpark_procedures", + "snowpark_procedures_coverage", + "snowpark_function_fully_qualified_name", + ], +) +def test_v1_to_v2_conversion( + project_directory, project_name: str +): # project_name: str, expected_values: Dict[str, Any]): + + with project_directory(project_name) as project_dir: + definition_v1 = DefinitionManager(project_dir).project_definition + definition_v2 = _migrate_v1_snowpark_to_v2(definition_v1) + assert definition_v2.definition_version == "2" + assert ( + definition_v1.snowpark.project_name == definition_v2.defaults.project_name + ) + assert len(definition_v1.snowpark.procedures) == len( + definition_v2.get_entities_by_type("procedure") + ) + assert len(definition_v1.snowpark.functions) == len( + definition_v2.get_entities_by_type("function") + ) + + for v1_procedure in definition_v1.snowpark.procedures: + v2_procedure = definition_v2.entities.get(v1_procedure.name) + assert v2_procedure + assert v2_procedure.artifacts == definition_v1.snowpark.src + assert _compare_entity(v1_procedure, v2_procedure) + + for v1_function in definition_v1.snowpark.functions: + v2_function = definition_v2.entities.get(v1_function.name) + assert v2_function + assert v2_function.artifacts == definition_v1.snowpark.src + assert _compare_entity(v1_function, v2_function) + + +def _compare_entity(v1_entity: _CallableBase, v2_entity: SnowparkEntityModel) -> bool: + return ( + v1_entity.name == v2_entity.name + and v1_entity.handler == v2_entity.handler + and v1_entity.returns == v2_entity.returns + and v1_entity.signature == v2_entity.signature + and v1_entity.runtime == v2_entity.runtime + ) diff --git a/tests_integration/test_snowpark.py b/tests_integration/test_snowpark.py index e256eb01c..f71c5de8c 100644 --- a/tests_integration/test_snowpark.py +++ b/tests_integration/test_snowpark.py @@ -20,6 +20,7 @@ import pytest + from tests_integration.testing_utils import ( SnowparkTestSteps, ) @@ -860,10 +861,10 @@ def test_incorrect_requirements(project_directory, runner, alter_requirements_tx ) with pytest.raises(InvalidRequirement) as err: runner.invoke_with_connection(["snowpark", "build"]) - assert ( - "Expected end or semicolon (after name and no valid version specifier)" - in str(err) - ) + assert ( + "Expected end or semicolon (after name and no valid version specifier)" + in str(err) + ) @pytest.fixture