diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 092b1f563..7b9279410 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -40,6 +40,7 @@ but should be replaced with `snow init` * Added connection option `--token-file-path` allowing passing OAuth token using a file. The function is also supported by setting `token_file_path` in connection definition. * Support for Python remote execution via `snow stage execute` and `snow git execute` similar to existing EXECUTE IMMEDIATE support. +* Added support for project definition file defaults in templates * Added support for autocomplete in `--connection` flag. * Added `snow init` command, which supports initializing projects with external templates. diff --git a/src/snowflake/cli/api/project/schemas/entities/application_entity.py b/src/snowflake/cli/api/project/schemas/entities/application_entity.py index 983e9b15d..9c2c7bea6 100644 --- a/src/snowflake/cli/api/project/schemas/entities/application_entity.py +++ b/src/snowflake/cli/api/project/schemas/entities/application_entity.py @@ -16,7 +16,7 @@ from typing import Literal, Optional -from pydantic import AliasChoices, Field +from pydantic import Field from snowflake.cli.api.project.schemas.entities.application_package_entity import ( ApplicationPackageEntity, ) @@ -25,26 +25,20 @@ TargetField, ) from snowflake.cli.api.project.schemas.updatable_model import ( - UpdatableModel, + DiscriminatorField, ) class ApplicationEntity(EntityBase): - type: Literal["application"] # noqa: A003 + type: Literal["application"] = DiscriminatorField() # noqa A003 name: str = Field( title="Name of the application created when this entity is deployed" ) - from_: ApplicationFromField = Field( - validation_alias=AliasChoices("from"), + from_: TargetField[ApplicationPackageEntity] = Field( + alias="from", title="An application package this entity should be created from", ) debug: Optional[bool] = Field( title="Whether to enable debug mode when using a named stage to create an application object", default=None, ) - - -class ApplicationFromField(UpdatableModel): - target: TargetField[ApplicationPackageEntity] = Field( - title="Reference to an application package entity", - ) diff --git a/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py b/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py index e86617bcd..e8cd24e2c 100644 --- a/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py +++ b/src/snowflake/cli/api/project/schemas/entities/application_package_entity.py @@ -22,11 +22,14 @@ ) from snowflake.cli.api.project.schemas.native_app.package import DistributionOptions from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping -from snowflake.cli.api.project.schemas.updatable_model import IdentifierField +from snowflake.cli.api.project.schemas.updatable_model import ( + DiscriminatorField, + IdentifierField, +) class ApplicationPackageEntity(EntityBase): - type: Literal["application package"] # noqa: A003 + type: Literal["application package"] = DiscriminatorField() # noqa: A003 name: str = Field( title="Name of the application package created when this entity is deployed" ) diff --git a/src/snowflake/cli/api/project/schemas/entities/common.py b/src/snowflake/cli/api/project/schemas/entities/common.py index 5b50ac1f3..262f34580 100644 --- a/src/snowflake/cli/api/project/schemas/entities/common.py +++ b/src/snowflake/cli/api/project/schemas/entities/common.py @@ -17,8 +17,7 @@ from abc import ABC from typing import Generic, List, Optional, TypeVar -from pydantic import AliasChoices, Field, GetCoreSchemaHandler, ValidationInfo -from pydantic_core import core_schema +from pydantic import Field from snowflake.cli.api.project.schemas.native_app.application import ( ApplicationPostDeployHook, ) @@ -45,7 +44,7 @@ class MetaField(UpdatableModel): class DefaultsField(UpdatableModel): schema_: Optional[str] = Field( title="Schema.", - validation_alias=AliasChoices("schema"), + alias="schema", default=None, ) stage: Optional[str] = Field( @@ -65,21 +64,15 @@ def get_type(cls) -> str: TargetType = TypeVar("TargetType") -class TargetField(Generic[TargetType]): - def __init__(self, entity_target_key: str): - self.value = entity_target_key - - def __repr__(self): - return self.value - - @classmethod - def validate(cls, value: str, info: ValidationInfo) -> TargetField: - return cls(value) +class TargetField(UpdatableModel, Generic[TargetType]): + target: str = Field( + title="Reference to a target entity", + ) - @classmethod - def __get_pydantic_core_schema__( - cls, source_type, handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - return core_schema.with_info_after_validator_function( - cls.validate, handler(str), field_name=handler.field_name - ) + def get_type(self) -> type: + """ + Returns the generic type of this class, indicating the entity type. + Pydantic extracts Generic annotations, and populates + them in __pydantic_generic_metadata__ + """ + return self.__pydantic_generic_metadata__["args"][0] diff --git a/src/snowflake/cli/api/project/schemas/native_app/application.py b/src/snowflake/cli/api/project/schemas/native_app/application.py index cd393dec7..84f05066a 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/application.py +++ b/src/snowflake/cli/api/project/schemas/native_app/application.py @@ -52,3 +52,11 @@ class Application(UpdatableModel): title="Actions that will be executed after the application object is created/upgraded", default=None, ) + + +class ApplicationV11(Application): + # Templated defaults only supported in v1.1+ + name: Optional[str] = Field( + title="Name of the application object created when you run the snow app run command", + default="<% fn.id_concat(ctx.native_app.name, '_', fn.clean_id(fn.get_username('unknown_user'))) %>", + ) diff --git a/src/snowflake/cli/api/project/schemas/native_app/native_app.py b/src/snowflake/cli/api/project/schemas/native_app/native_app.py index 6ac648911..d12ccaa17 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/native_app.py +++ b/src/snowflake/cli/api/project/schemas/native_app/native_app.py @@ -18,8 +18,11 @@ from typing import List, Optional, Union from pydantic import Field, field_validator -from snowflake.cli.api.project.schemas.native_app.application import Application -from snowflake.cli.api.project.schemas.native_app.package import Package +from snowflake.cli.api.project.schemas.native_app.application import ( + Application, + ApplicationV11, +) +from snowflake.cli.api.project.schemas.native_app.package import Package, PackageV11 from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel from snowflake.cli.api.project.util import ( @@ -80,3 +83,11 @@ def transform_artifacts( transformed_artifacts.append(PathMapping(src=artifact)) return transformed_artifacts + + +class NativeAppV11(NativeApp): + # templated defaults are only supported with version 1.1+ + package: Optional[PackageV11] = Field(title="PackageSchema", default=PackageV11()) + application: Optional[ApplicationV11] = Field( + title="Application info", default=ApplicationV11() + ) diff --git a/src/snowflake/cli/api/project/schemas/native_app/package.py b/src/snowflake/cli/api/project/schemas/native_app/package.py index f62d4ca55..e696ed4de 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/package.py +++ b/src/snowflake/cli/api/project/schemas/native_app/package.py @@ -53,3 +53,11 @@ def validate_scripts(cls, input_list): "package.scripts field should contain unique values. Check the list for duplicates and try again" ) return input_list + + +class PackageV11(Package): + # Templated defaults only supported in v1.1+ + name: Optional[str] = IdentifierField( + title="Name of the application package created when you run the snow app run command", + default="<% fn.id_concat(ctx.native_app.name, '_pkg_', fn.clean_id(fn.get_username('unknown_user'))) %>", + ) diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 03385b79c..74ea584bf 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -32,11 +32,13 @@ Entity, v2_entity_types_map, ) -from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp +from snowflake.cli.api.project.schemas.native_app.native_app import ( + NativeApp, + NativeAppV11, +) from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel -from snowflake.cli.api.utils.models import ProjectEnvironment from snowflake.cli.api.utils.types import Context from typing_extensions import Annotated @@ -99,22 +101,14 @@ class DefinitionV10(_ProjectDefinitionBase): class DefinitionV11(DefinitionV10): - env: Union[Dict[str, str], ProjectEnvironment, None] = Field( - title="Environment specification for this project.", + native_app: Optional[NativeAppV11] = Field( + title="Native app definitions for the project", default=None + ) + env: Optional[Dict[str, Union[str, int, bool]]] = Field( + title="Default environment specification for this project.", default=None, - validation_alias="env", - union_mode="smart", ) - @field_validator("env") - @classmethod - def _convert_env( - cls, env: Union[Dict, ProjectEnvironment, None] - ) -> ProjectEnvironment: - if isinstance(env, ProjectEnvironment): - return env - return ProjectEnvironment(default_env=(env or {}), override_env={}) - class DefinitionV20(_ProjectDefinitionBase): entities: Dict[str, Annotated[Entity, Field(discriminator="type")]] = Field( @@ -147,10 +141,10 @@ def validate_entities(cls, entities: Dict[str, Entity]) -> Dict[str, Entity]: for key, entity in entities.items(): # TODO Automatically detect TargetFields to validate if entity.type == ApplicationEntity.get_type(): - if isinstance(entity.from_.target, TargetField): - target_key = str(entity.from_.target) - target_class = entity.from_.__class__.model_fields["target"] - target_type = target_class.annotation.__args__[0] + if isinstance(entity.from_, TargetField): + target_key = entity.from_.target + target_object = entity.from_ + target_type = target_object.get_type() cls._validate_target_field(target_key, target_type, entities) return entities @@ -160,37 +154,26 @@ def _validate_target_field( ): if target_key not in entities: raise ValueError(f"No such target: {target_key}") - else: - # Validate the target type - actual_target_type = entities[target_key].__class__ - if target_type and target_type is not actual_target_type: - raise ValueError( - f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" - ) + + # Validate the target type + actual_target_type = entities[target_key].__class__ + if target_type and target_type is not actual_target_type: + raise ValueError( + f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}" + ) defaults: Optional[DefaultsField] = Field( title="Default key/value entity values that are merged recursively for each entity.", default=None, ) - env: Union[Dict[str, str], ProjectEnvironment, None] = Field( - title="Environment specification for this project.", + env: Optional[Dict[str, Union[str, int, bool]]] = Field( + title="Default environment specification for this project.", default=None, - validation_alias="env", - union_mode="smart", ) - @field_validator("env") - @classmethod - def _convert_env( - cls, env: Union[Dict, ProjectEnvironment, None] - ) -> ProjectEnvironment: - if isinstance(env, ProjectEnvironment): - return env - return ProjectEnvironment(default_env=(env or {}), override_env={}) - -def build_project_definition(**data): +def build_project_definition(**data) -> ProjectDefinition: """ Returns a ProjectDefinition instance with a version matching the provided definition_version value """ diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py index 6039b7736..103c54906 100644 --- a/src/snowflake/cli/api/project/schemas/snowpark/callable.py +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -19,9 +19,7 @@ from pydantic import Field, field_validator from snowflake.cli.api.project.schemas.identifier_model import ObjectIdentifierModel from snowflake.cli.api.project.schemas.snowpark.argument import Argument -from snowflake.cli.api.project.schemas.updatable_model import ( - UpdatableModel, -) +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel class _CallableBase(UpdatableModel): diff --git a/src/snowflake/cli/api/project/schemas/updatable_model.py b/src/snowflake/cli/api/project/schemas/updatable_model.py index be0ba4b14..56b4445bb 100644 --- a/src/snowflake/cli/api/project/schemas/updatable_model.py +++ b/src/snowflake/cli/api/project/schemas/updatable_model.py @@ -14,17 +14,149 @@ from __future__ import annotations -from typing import Any, Dict +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Dict, Iterator, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationInfo, + field_validator, +) +from pydantic.fields import FieldInfo from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH +PROJECT_TEMPLATE_START = "<%" + + +def _is_templated(info: ValidationInfo, value: Any) -> bool: + return ( + info.context + and info.context.get("skip_validation_on_templates", False) + and isinstance(value, str) + and PROJECT_TEMPLATE_START in value + ) + + +_initial_context: ContextVar[Optional[Dict[str, Any]]] = ContextVar( + "_init_context_var", default=None +) + + +@contextmanager +def context(value: Dict[str, Any]) -> Iterator[None]: + """ + Thread safe context for Pydantic. + By using `with context()`, you ensure context changes apply + to the with block only + """ + token = _initial_context.set(value) + try: + yield + finally: + _initial_context.reset(token) + class UpdatableModel(BaseModel): model_config = ConfigDict(validate_assignment=True, extra="forbid") - def __init__(self, *args, **kwargs): - super().__init__(**kwargs) + def __init__(self, /, **data: Any) -> None: + """ + Pydantic provides 2 options to pass in context: + 1) Through `model_validate()` as a second argument. + 2) Through a custom init method and the use of ContextVar + + We decided not to use 1) because it silently stops working + if someone adds a pass through __init__ to any of the Pydantic models. + + We decided to go with 2) as the safer approach. + Calling validate_python() in the __init__ is how we can pass context + on initialization according to Pydantic's documentation: + https://docs.pydantic.dev/latest/concepts/validators/#using-validation-context-with-basemodel-initialization + """ + self.__pydantic_validator__.validate_python( + data, + self_instance=self, + context=_initial_context.get(), + ) + + @classmethod + def _is_entity_type_field(cls, field: Any) -> bool: + """ + Checks if a field is of type `DiscriminatorField` + """ + if not isinstance(field, FieldInfo) or not field.json_schema_extra: + return False + + return ( + "is_discriminator_field" in field.json_schema_extra + and field.json_schema_extra["is_discriminator_field"] + ) + + @classmethod + def __init_subclass__(cls, **kwargs): + """ + This method will collect all the Pydantic annotations for the class + currently being initialized (any subclass of `UpdatableModel`). + + It will add a field validator wrapper for every Pydantic field + in order to skip validation when templates are found. + + It will apply this to all Pydantic fields, except for fields + marked as `DiscriminatorField`. These will be skipped because + Pydantic does not support validators for discriminator field types. + """ + + super().__init_subclass__(**kwargs) + + field_annotations = {} + field_values = {} + # Go through the inheritance classes and collect all the annotations and + # all the values of the class attributes. We go in reverse order so that + # values in subclasses overrides values from parent classes in case of field overrides. + + for class_ in reversed(cls.__mro__): + class_dict = class_.__dict__ + field_annotations.update(class_dict.get("__annotations__", {})) + + if "model_fields" in class_dict: + # This means the class dict has already been processed by Pydantic + # All fields should properly be populated in model_fields + field_values.update(class_dict["model_fields"]) + else: + # If Pydantic did not process this class yet, get the values from class_dict directly + field_values.update(class_dict) + + # Add Pydantic validation wrapper around all fields except `DiscriminatorField`s + for field_name in field_annotations: + if not cls._is_entity_type_field(field_values.get(field_name)): + cls._add_validator(field_name) + + @classmethod + def _add_validator(cls, field_name: str): + """ + Adds a Pydantic validator with mode=wrap for the provided `field_name`. + During validation, this will check if the field is templated (not expanded yet) + and in that case, it will skip all the remaining Pydantic validation on that field. + + Since this validator is added last, it will skip all the other field validators + defined in the subclasses when templates are found. + + This logic on templates only applies when context contains `skip_validation_on_templates` flag. + """ + + def validator_skipping_templated_str(cls, value, handler, info: ValidationInfo): + if _is_templated(info, value): + return value + return handler(value) + + setattr( + cls, + f"_field_validator_with_verbose_name_to_avoid_name_conflict_{field_name}", + field_validator(field_name, mode="wrap")(validator_skipping_templated_str), + ) def update_from_dict(self, update_values: Dict[str, Any]): """ @@ -47,5 +179,16 @@ def update_from_dict(self, update_values: Dict[str, Any]): return self -def IdentifierField(*args, **kwargs): # noqa +def DiscriminatorField(*args, **kwargs): # noqa N802 + """ + Use this type for discriminator fields used for differentiating + between different entity types. + + When this `DiscriminatorField` is used on a pydantic attribute, + we will not allow templating on it. + """ + return Field(is_discriminator_field=True, *args, **kwargs) + + +def IdentifierField(*args, **kwargs): # noqa N802 return Field(max_length=254, pattern=IDENTIFIER_NO_LENGTH, *args, **kwargs) diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py index 4abe6712e..2ea1bdd55 100644 --- a/src/snowflake/cli/api/project/util.py +++ b/src/snowflake/cli/api/project/util.py @@ -17,7 +17,7 @@ import codecs import os import re -from typing import Optional +from typing import List, Optional from urllib.parse import quote IDENTIFIER = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]{0,254}))' @@ -88,6 +88,18 @@ def is_valid_object_name(name: str, max_depth=2, allow_quoted=True) -> bool: return re.fullmatch(pattern, name) is not None +def to_quoted_identifier(input_value: str) -> str: + """ + Turn the input into a valid quoted identifier. + If it is already a valid quoted identifier, + return it as is. + """ + if is_valid_quoted_identifier(input_value): + return input_value + + return '"' + input_value.replace('"', '""') + '"' + + def to_identifier(name: str) -> str: """ Converts a name to a valid Snowflake identifier. If the name is already a valid @@ -96,8 +108,15 @@ def to_identifier(name: str) -> str: if is_valid_identifier(name): return name - # double quote the identifier - return '"' + name.replace('"', '""') + '"' + return to_quoted_identifier(name) + + +def identifier_to_str(identifier: str) -> str: + if is_valid_quoted_identifier(identifier): + unquoted_id = identifier[1:-1] + return unquoted_id.replace('""', '"') + + return identifier def append_to_identifier(identifier: str, suffix: str) -> str: @@ -183,6 +202,27 @@ def get_env_username() -> Optional[str]: return first_set_env("USER", "USERNAME", "LOGNAME") +def concat_identifiers(identifiers: list[str]) -> str: + """ + Concatenate multiple identifiers. + If any of them is quoted identifier or contains unsafe characters, quote the result. + Otherwise, the resulting identifier will be unquoted. + """ + quotes_found = False + stringified_identifiers: List[str] = [] + + for identifier in identifiers: + if is_valid_quoted_identifier(identifier): + quotes_found = True + stringified_identifiers.append(identifier_to_str(identifier)) + + concatenated_ids_str = "".join(stringified_identifiers) + if quotes_found: + return to_quoted_identifier(concatenated_ids_str) + + return to_identifier(concatenated_ids_str) + + SUPPORTED_VERSIONS = [1] diff --git a/src/snowflake/cli/api/rendering/jinja.py b/src/snowflake/cli/api/rendering/jinja.py index c37948c1c..299cb8ac8 100644 --- a/src/snowflake/cli/api/rendering/jinja.py +++ b/src/snowflake/cli/api/rendering/jinja.py @@ -24,6 +24,7 @@ from snowflake.cli.api.secure_path import UNLIMITED, SecurePath CONTEXT_KEY = "ctx" +FUNCTION_KEY = "fn" def read_file_content(file_name: str): diff --git a/src/snowflake/cli/api/rendering/sql_templates.py b/src/snowflake/cli/api/rendering/sql_templates.py index fd77eb585..b2eea68e7 100644 --- a/src/snowflake/cli/api/rendering/sql_templates.py +++ b/src/snowflake/cli/api/rendering/sql_templates.py @@ -21,12 +21,14 @@ from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.rendering.jinja import ( CONTEXT_KEY, + FUNCTION_KEY, IgnoreAttrEnvironment, env_bootstrap, ) _SQL_TEMPLATE_START = "&{" _SQL_TEMPLATE_END = "}" +RESERVED_KEYS = [CONTEXT_KEY, FUNCTION_KEY] def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None): @@ -46,10 +48,12 @@ def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None): def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: data = data or {} - if CONTEXT_KEY in data: - raise ClickException( - f"{CONTEXT_KEY} in user defined data. The `{CONTEXT_KEY}` variable is reserved for CLI usage." - ) + + for reserved_key in RESERVED_KEYS: + if reserved_key in data: + raise ClickException( + f"{reserved_key} in user defined data. The `{reserved_key}` variable is reserved for CLI usage." + ) context_data = get_cli_context().template_context context_data.update(data) diff --git a/src/snowflake/cli/api/utils/definition_rendering.py b/src/snowflake/cli/api/utils/definition_rendering.py index e468c12a1..a3bcd990d 100644 --- a/src/snowflake/cli/api/utils/definition_rendering.py +++ b/src/snowflake/cli/api/utils/definition_rendering.py @@ -25,6 +25,7 @@ ProjectProperties, build_project_definition, ) +from snowflake.cli.api.project.schemas.updatable_model import context from snowflake.cli.api.rendering.jinja import CONTEXT_KEY from snowflake.cli.api.rendering.project_definition_templates import ( get_project_definition_cli_jinja_env, @@ -32,6 +33,7 @@ from snowflake.cli.api.utils.dict_utils import traverse from snowflake.cli.api.utils.graph import Graph, Node from snowflake.cli.api.utils.models import ProjectEnvironment +from snowflake.cli.api.utils.templating_functions import get_templating_functions from snowflake.cli.api.utils.types import Context, Definition @@ -81,7 +83,16 @@ def _get_referenced_vars( all_referenced_vars.add(TemplateVar(current_attr_chain)) current_attr_chain = None elif ( - not isinstance(ast_node, (nodes.Template, nodes.TemplateData, nodes.Output)) + not isinstance( + ast_node, + ( + nodes.Template, + nodes.TemplateData, + nodes.Output, + nodes.Call, + nodes.Const, + ), + ) or current_attr_chain is not None ): raise InvalidTemplate(f"Unexpected templating syntax in {template_value}") @@ -199,7 +210,6 @@ def _build_dependency_graph( dependencies_graph = Graph[TemplateVar]() for variable in all_vars: dependencies_graph.add(Node[TemplateVar](key=variable.key, data=variable)) - for variable in all_vars: # If variable is found in os.environ or from cli override, then use the value as is # skip rendering by pre-setting the rendered_value attribute @@ -262,6 +272,17 @@ def _template_version_warning(): ) +def _add_defaults_to_definition(definition: Definition) -> Definition: + with context({"skip_validation_on_templates": True}): + # pass a flag to Pydantic to skip validation for templated scalars + # populate the defaults + project_definition = build_project_definition(**definition) + + return project_definition.model_dump( + exclude_none=True, warnings=False, by_alias=True + ) + + def render_definition_template( original_definition: Optional[Definition], context_overrides: Context ) -> ProjectProperties: @@ -276,11 +297,14 @@ def render_definition_template( Environment variables take precedence during the rendering process. """ - # protect input from update + # copy input to protect it from update definition = copy.deepcopy(original_definition) - # start with an environment from overrides and environment variables: + # collect all the override --env variables passed through CLI input override_env = context_overrides.get(CONTEXT_KEY, {}).get("env", {}) + + # set up Project Environment with empty default_env because + # default env section from project definition file is still templated at this time environment_overrides = ProjectEnvironment( default_env={}, override_env=override_env ) @@ -288,7 +312,6 @@ def render_definition_template( if definition is None: return ProjectProperties(None, {CONTEXT_KEY: {"env": environment_overrides}}) - project_context = {CONTEXT_KEY: definition} template_env = TemplatedEnvironment(get_project_definition_cli_jinja_env()) if "definition_version" not in definition or Version( @@ -304,12 +327,18 @@ def render_definition_template( # also warn on Exception, as it means the user is incorrectly attempting to use templating _template_version_warning() - project_definition = build_project_definition(**original_definition) + project_definition = build_project_definition(**definition) + project_context = {CONTEXT_KEY: definition} project_context[CONTEXT_KEY]["env"] = environment_overrides return ProjectProperties(project_definition, project_context) - default_env = definition.get("env", {}) - _validate_env_section(default_env) + definition = _add_defaults_to_definition(definition) + project_context = {CONTEXT_KEY: definition} + + _validate_env_section(definition.get("env", {})) + + # add available templating functions + project_context["fn"] = get_templating_functions() referenced_vars = _get_referenced_vars_in_definition(template_env, definition) @@ -338,7 +367,11 @@ def on_cycle_action(node: Node[TemplateVar]): update_action=lambda val: template_env.render(val, final_context), ) - definition["env"] = ProjectEnvironment(default_env, override_env) - project_context[CONTEXT_KEY] = definition project_definition = build_project_definition(**definition) + project_context[CONTEXT_KEY] = definition + # Use `ProjectEnvironment` in project context in order to + # handle env variables overrides from OS env and from CLI arguments. + project_context[CONTEXT_KEY]["env"] = ProjectEnvironment( + default_env=project_context[CONTEXT_KEY].get("env"), override_env=override_env + ) return ProjectProperties(project_definition, project_context) diff --git a/src/snowflake/cli/api/utils/models.py b/src/snowflake/cli/api/utils/models.py index ac24b9347..b6a1d2167 100644 --- a/src/snowflake/cli/api/utils/models.py +++ b/src/snowflake/cli/api/utils/models.py @@ -15,12 +15,12 @@ from __future__ import annotations import os +from dataclasses import dataclass from typing import Any, Dict, Optional -from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel - -class ProjectEnvironment(UpdatableModel): +@dataclass +class ProjectEnvironment: """ This class handles retrieval of project env variables. These env variables can be accessed through templating, as ctx.env. @@ -31,13 +31,16 @@ class ProjectEnvironment(UpdatableModel): - Check for default values from the project definition file. """ - override_env: Dict[str, Any] = {} - default_env: Dict[str, Any] = {} + override_env: Dict[str, Any] + default_env: Dict[str, Any] def __init__( - self, default_env: Dict[str, Any], override_env: Optional[Dict[str, Any]] = None + self, + default_env: Optional[Dict[str, Any]] = None, + override_env: Optional[Dict[str, Any]] = None, ): - super().__init__(self, default_env=default_env, override_env=override_env or {}) + self.override_env = override_env or {} + self.default_env = default_env or {} def __getitem__(self, item): if item in self.override_env: diff --git a/src/snowflake/cli/api/utils/templating_functions.py b/src/snowflake/cli/api/utils/templating_functions.py new file mode 100644 index 000000000..e43907da9 --- /dev/null +++ b/src/snowflake/cli/api/utils/templating_functions.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List, Optional + +from snowflake.cli.api.exceptions import InvalidTemplate +from snowflake.cli.api.project.util import ( + clean_identifier, + concat_identifiers, + get_env_username, + identifier_to_str, + to_identifier, +) + + +class TemplatingFunctions: + """ + This class contains all the functions available for templating. + Any callable not starting with '_' will automatically be available for users to use. + """ + + @staticmethod + def _verify_str_arguments( + func_name: str, + args: List[Any], + *, + min_count: Optional[int] = None, + max_count: Optional[int] = None, + ): + if min_count is not None and len(args) < min_count: + raise InvalidTemplate( + f"{func_name} requires at least {min_count} argument(s)" + ) + + if max_count is not None and len(args) > max_count: + raise InvalidTemplate( + f"{func_name} supports at most {max_count} argument(s)" + ) + + for arg in args: + if not isinstance(arg, str): + raise InvalidTemplate(f"{func_name} only accepts String values") + + @staticmethod + def id_concat(*args): + """ + input: one or more string arguments (SQL ID or plain String). + output: a valid SQL ID (quoted or unquoted) + + Takes on multiple String arguments and concatenate them into one String. + If any of the Strings is a valid quoted ID, it will be unescaped for the concatenation process. + The resulting String is then escaped and quoted if: + - It contains non SQL safe characters + - Any of the input was a valid quoted identifier. + """ + TemplatingFunctions._verify_str_arguments("id_concat", args, min_count=1) + return concat_identifiers(args) + + @staticmethod + def str_to_id(*args): + """ + input: one string argument. (SQL ID or plain String) + output: a valid SQL ID (quoted or unquoted) + + If the input is a valid quoted or valid unquoted identifier, return it as is. + Otherwise, if the input contains unsafe characters and is not properly quoted, + then escape it and quote it. + """ + TemplatingFunctions._verify_str_arguments( + "str_to_id", args, min_count=1, max_count=1 + ) + return to_identifier(args[0]) + + @staticmethod + def id_to_str(*args): + """ + input: one string argument (SQL ID or plain String). + output: a plain string + + If the input is a valid SQL ID, then unescape it and return the plain String version. + Otherwise, return the input as is. + """ + TemplatingFunctions._verify_str_arguments( + "id_to_str", args, min_count=1, max_count=1 + ) + return identifier_to_str(args[0]) + + @staticmethod + def get_username(*args): + """ + input: one optional string containing the fallback value + output: current username detected from the Operating System + + If the current username is not found or is blank, return blank + or use the fallback value if provided. + """ + TemplatingFunctions._verify_str_arguments( + "get_username", args, min_count=0, max_count=1 + ) + fallback_username = args[0] if len(args) > 0 else "" + return get_env_username() or fallback_username + + @staticmethod + def clean_id(*args): + """ + input: one string argument + output: a valid non-quoted SQL ID + + Removes any unsafe SQL characters from the input, lowercase it, + and return it as a valid unquoted SQL ID. + """ + TemplatingFunctions._verify_str_arguments( + "clean_id", args, min_count=1, max_count=1 + ) + + return clean_identifier(args[0]) + + +def get_templating_functions(): + """ + Returns a dictionary with all the functions available for templating + """ + templating_functions = { + func: getattr(TemplatingFunctions, func) + for func in dir(TemplatingFunctions) + if callable(getattr(TemplatingFunctions, func)) and not func.startswith("_") + } + + return templating_functions diff --git a/src/snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py b/src/snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py index 20b4cb859..d552ca56d 100644 --- a/src/snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py +++ b/src/snowflake/cli/plugins/nativeapp/v2_conversions/v2_to_v1_decorator.py @@ -37,6 +37,7 @@ DefinitionV11, DefinitionV20, ) +from snowflake.cli.api.utils.definition_rendering import render_definition_template def _convert_v2_artifact_to_v1_dict( @@ -135,8 +136,9 @@ def _pdf_v2_to_v1(v2_definition: DefinitionV20) -> DefinitionV11: "post_deploy" ] = app_definition.meta.post_deploy + result = render_definition_template(pdfv1, {}) # Override the definition object in global context - return DefinitionV11(**pdfv1) + return result.project_definition def nativeapp_definition_v2_to_v1(func): diff --git a/tests/api/project/schemas/test_updatable_model.py b/tests/api/project/schemas/test_updatable_model.py new file mode 100644 index 000000000..ef413da27 --- /dev/null +++ b/tests/api/project/schemas/test_updatable_model.py @@ -0,0 +1,251 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pydantic import ValidationError, field_validator +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel, context + + +def test_updatable_model_including_other_models(): + class TestIncludedModel(UpdatableModel): + c: str + + class TestModel(UpdatableModel): + a: str + b: TestIncludedModel + + test_input = {"a": "a_value", "b": {"c": "c_value"}} + result = TestModel(**test_input) + + assert result.a == "a_value" + assert result.b is not None + assert result.b.c == "c_value" + + +def test_updatable_model_with_sub_class_models(): + class ParentModel(UpdatableModel): + a: str + + class ChildModel(ParentModel): + a: str + b: str + + test_input = {"a": "a_value", "b": "b_value"} + result = ChildModel(**test_input) + + assert result.a == "a_value" + assert result.b == "b_value" + + +def test_updatable_model_with_validators(): + class TestModel(UpdatableModel): + a: str + + @field_validator("a", mode="before") + @classmethod + def validate_a_before(cls, value): + if value != "expected_value": + raise ValueError("Invalid Value") + return value + + @field_validator("a", mode="after") + @classmethod + def validate_a_after(cls, value): + if value != "expected_value": + raise ValueError("Invalid Value") + return value + + @field_validator("a", mode="wrap") + @classmethod + def validate_a_wrap(cls, value, handler): + if value != "expected_value": + raise ValueError("Invalid Value") + result = handler(value) + if result != "expected_value": + raise ValueError("Invalid Value") + return result + + result = TestModel(a="expected_value") + + assert result.a == "expected_value" + + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Invalid Value" in str(e.value) + + with pytest.raises(ValidationError) as e: + TestModel(a="<% sometemplate %>") + assert "Invalid Value" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Invalid Value" in str(e.value) + + result = TestModel(a="<% sometemplate %>") + assert result.a == "<% sometemplate %>" + + +def test_updatable_model_with_plain_validator(): + class TestModel(UpdatableModel): + a: str + + @field_validator("a", mode="plain") + @classmethod + def validate_a_plain(cls, value): + if value != "expected_value": + raise ValueError("Invalid Value") + return value + + result = TestModel(a="expected_value") + assert result.a == "expected_value" + + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Invalid Value" in str(e.value) + + with pytest.raises(ValidationError) as e: + TestModel(a="<% sometemplate %>") + assert "Invalid Value" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Invalid Value" in str(e.value) + + result = TestModel(a="<% sometemplate %>") + assert result.a == "<% sometemplate %>" + + +def test_updatable_model_with_int_and_templates(): + class TestModel(UpdatableModel): + a: int + + result = TestModel(a="123") + assert result.a == 123 + + with pytest.raises(ValidationError) as e: + TestModel(a="<% sometemplate %>") + assert "Input should be a valid integer" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Input should be a valid integer" in str(e.value) + + result = TestModel(a="<% sometemplate %>") + assert result.a == "<% sometemplate %>" + + +def test_updatable_model_with_bool_and_templates(): + class TestModel(UpdatableModel): + a: bool + + result = TestModel(a="true") + assert result.a is True + + with pytest.raises(ValidationError) as e: + TestModel(a="<% sometemplate %>") + assert "Input should be a valid boolean" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + TestModel(a="abc") + assert "Input should be a valid boolean" in str(e.value) + + result = TestModel(a="<% sometemplate %>") + assert result.a == "<% sometemplate %>" + + +def test_updatable_model_with_sub_classes_and_template_values(): + class ParentModel(UpdatableModel): + a: str + + class ChildModel(ParentModel): + b: int + + result = ChildModel(a="any_value", b="123") + assert result.b == 123 + + with pytest.raises(ValidationError) as e: + ChildModel(a="any_value", b="<% sometemplate %>") + assert "Input should be a valid integer" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + ChildModel(a="any_value", b="abc") + assert "Input should be a valid integer" in str(e.value) + + result = ChildModel(a="any_value", b="<% sometemplate %>") + assert result.b == "<% sometemplate %>" + + +def test_updatable_model_with_sub_classes_and_template_values_and_custom_validator_in_parent(): + class ParentModel(UpdatableModel): + a: str + + @field_validator("a", mode="before") + @classmethod + def validate_a_before(cls, value): + if value != "expected_value": + raise ValueError("Invalid Value") + return value + + class ChildModel(ParentModel): + b: str + + result = ChildModel(a="expected_value", b="any_value") + assert result.a == "expected_value" + + with pytest.raises(ValidationError) as e: + ChildModel(a="<% sometemplate %>", b="any_value") + assert "Invalid Value" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + ChildModel(a="abc", b="any_value") + assert "Invalid Value" in str(e.value) + + result = ChildModel(a="<% sometemplate %>", b="any_value") + assert result.a == "<% sometemplate %>" + + +def test_updatable_model_with_sub_classes_and_template_values_and_custom_validator_in_child(): + class ParentModel(UpdatableModel): + a: str + + class ChildModel(ParentModel): + b: str + + @field_validator("b", mode="before") + @classmethod + def validate_b_before(cls, value): + if value != "expected_value": + raise ValueError("Invalid Value") + return value + + result = ChildModel(a="any_value", b="expected_value") + assert result.b == "expected_value" + + with pytest.raises(ValidationError) as e: + ChildModel(a="any_value", b="<% sometemplate %>") + assert "Invalid Value" in str(e.value) + + with context({"skip_validation_on_templates": True}): + with pytest.raises(ValidationError) as e: + ChildModel(a="any_value", b="abc") + assert "Invalid Value" in str(e.value) + + result = ChildModel(a="any_value", b="<% sometemplate %>") + assert result.b == "<% sometemplate %>" diff --git a/tests/api/test_rendering.py b/tests/api/test_rendering.py index a4bcbf9a5..0829d89da 100644 --- a/tests/api/test_rendering.py +++ b/tests/api/test_rendering.py @@ -16,6 +16,7 @@ from unittest import mock import pytest +from click import ClickException from jinja2 import UndefinedError from snowflake.cli.api.rendering.sql_templates import snowflake_sql_jinja_render from snowflake.cli.api.utils.models import ProjectEnvironment @@ -96,6 +97,22 @@ def test_that_undefined_variables_raise_error(text, cli_context): snowflake_sql_jinja_render(text) +@pytest.mark.parametrize( + "key_word", + [ + "ctx", + "fn", + ], +) +def test_reserved_keywords_raise_error(key_word, cli_context): + with pytest.raises(ClickException) as err: + snowflake_sql_jinja_render("select 1;", data={key_word: "some_value"}) + assert ( + err.value.message + == f"{key_word} in user defined data. The `{key_word}` variable is reserved for CLI usage." + ) + + @mock.patch.dict(os.environ, {"TEST_ENV_VAR": "foo"}) def test_contex_can_access_environment_variable(cli_context): assert snowflake_sql_jinja_render("&{ ctx.env.TEST_ENV_VAR }") == os.environ.get( diff --git a/tests/api/utils/test_definition_rendering.py b/tests/api/utils/test_definition_rendering.py index a38d7b94d..4da047043 100644 --- a/tests/api/utils/test_definition_rendering.py +++ b/tests/api/utils/test_definition_rendering.py @@ -22,6 +22,7 @@ from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.utils.definition_rendering import render_definition_template from snowflake.cli.api.utils.models import ProjectEnvironment +from snowflake.cli.api.utils.templating_functions import get_templating_functions from tests.nativeapp.utils import NATIVEAPP_MODULE @@ -40,12 +41,14 @@ def test_resolve_variables_in_project_no_cross_variable_dependencies(): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - {"number": 1, "string": "foo", "boolean": True}, {} + default_env={"number": 1, "string": "foo", "boolean": True}, + override_env={}, ), - } + }, } @@ -62,10 +65,13 @@ def test_resolve_variables_in_project_cross_variable_dependencies(): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", - "env": ProjectEnvironment({"A": 42, "B": "b=42", "C": "b=42 and 42"}, {}), - } + "env": ProjectEnvironment( + default_env={"A": 42, "B": "b=42", "C": "b=42 and 42"}, override_env={} + ), + }, } @@ -96,7 +102,7 @@ def test_no_resolve_and_warning_in_version_1(warning_mock): "ctx": { "definition_version": "1", "native_app": {"name": "test_source_<% ctx.env.A %>", "artifacts": []}, - "env": ProjectEnvironment({}, {}), + "env": ProjectEnvironment(default_env={}, override_env={}), } } warning_mock.assert_called_once_with( @@ -118,7 +124,7 @@ def test_partial_invalid_template_in_version_1(warning_mock): "ctx": { "definition_version": "1", "native_app": {"name": "test_source_<% ctx.env.A", "artifacts": []}, - "env": ProjectEnvironment({}, {}), + "env": ProjectEnvironment(default_env={}, override_env={}), } } # we still want to warn if there was an incorrect attempt to use templating @@ -128,7 +134,7 @@ def test_partial_invalid_template_in_version_1(warning_mock): ) -@mock.patch.dict(os.environ, {"A": "value"}, clear=True) +@mock.patch.dict(os.environ, {"A": "value", "USER": "username"}, clear=True) @mock.patch(f"{NATIVEAPP_MODULE}.cc.warning") def test_no_warning_in_version_1_1(warning_mock): definition = { @@ -138,11 +144,25 @@ def test_no_warning_in_version_1_1(warning_mock): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", - "native_app": {"name": "test_source_value", "artifacts": []}, - "env": ProjectEnvironment({}, {}), - } + "native_app": { + "name": "test_source_value", + "artifacts": [], + "bundle_root": "output/bundle/", + "deploy_root": "output/deploy/", + "generated_root": "__generated/", + "scratch_stage": "app_src.stage_snowflake_cli_scratch", + "source_stage": "app_src.stage", + "package": { + "name": "test_source_value_pkg_username", + "distribution": "internal", + }, + "application": {"name": "test_source_value_username"}, + }, + "env": ProjectEnvironment(default_env={}, override_env={}), + }, } warning_mock.assert_not_called() @@ -172,16 +192,20 @@ def test_resolve_variables_in_project_cross_project_dependencies(): } result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", - "streamlit": {"name": "my_app"}, + "streamlit": { + "name": "my_app", + "main_file": "streamlit_app.py", + "query_warehouse": "streamlit", + "stage": "streamlit", + }, "env": ProjectEnvironment( - { - "app": "name of streamlit is my_app", - }, - {}, + default_env={"app": "name of streamlit is my_app"}, + override_env={}, ), - } + }, } @@ -207,17 +231,18 @@ def test_resolve_variables_in_project_environment_variables_precedence(): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - { + default_env={ "should_be_replaced_by_env": "test failed", "test_variable": "new_lowercase_value and new_uppercase_value", "test_variable_2": "this comes from os.environ", }, - {}, + override_env={}, ), - } + }, } assert result["ctx"]["env"]["lowercase"] == "new_lowercase_value" assert result["ctx"]["env"]["UPPERCASE"] == "new_uppercase_value" @@ -225,7 +250,11 @@ def test_resolve_variables_in_project_environment_variables_precedence(): assert result["ctx"]["env"]["value_from_env"] == "this comes from os.environ" -@mock.patch.dict(os.environ, {"env_var": "<% ctx.definition_version %>"}, clear=True) +@mock.patch.dict( + os.environ, + {"env_var": "<% ctx.definition_version %>", "USER": "username"}, + clear=True, +) def test_env_variables_do_not_get_resolved(): definition = { "definition_version": "1.1", @@ -235,21 +264,33 @@ def test_env_variables_do_not_get_resolved(): }, } result = render_definition_template(definition, {}).project_context - assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "native_app": { "name": "test_source_<% ctx.definition_version %>", "artifacts": [], + "bundle_root": "output/bundle/", + "deploy_root": "output/deploy/", + "generated_root": "__generated/", + "scratch_stage": "app_src.stage_snowflake_cli_scratch", + "source_stage": "app_src.stage", + "package": { + "name": '"test_source_<% ctx.definition_version %>_pkg_username"', + "distribution": "internal", + }, + "application": { + "name": '"test_source_<% ctx.definition_version %>_username"' + }, }, "env": ProjectEnvironment( - { + default_env={ "reference_to_name": "test_source_<% ctx.definition_version %>", }, - {}, + override_env={}, ), - } + }, } assert result["ctx"]["env"]["env_var"] == "<% ctx.definition_version %>" @@ -335,7 +376,7 @@ def test_resolve_variables_reference_non_scalar(definition, error_var): ) -@mock.patch.dict(os.environ, {"blank_env": ""}, clear=True) +@mock.patch.dict(os.environ, {"blank_env": "", "USER": "username"}, clear=True) def test_resolve_variables_blank_is_ok(): definition = { "definition_version": "1.1", @@ -351,16 +392,30 @@ def test_resolve_variables_blank_is_ok(): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", - "native_app": {"name": "", "deploy_root": "", "artifacts": []}, + "native_app": { + "name": "", + "deploy_root": "", + "artifacts": [], + "bundle_root": "output/bundle/", + "generated_root": "__generated/", + "scratch_stage": "app_src.stage_snowflake_cli_scratch", + "source_stage": "app_src.stage", + "package": { + "name": "_pkg_username", + "distribution": "internal", + }, + "application": {"name": "_username"}, + }, "env": ProjectEnvironment( - { + default_env={ "blank_default_env": "", }, - {}, + override_env={}, ), - } + }, } assert result["ctx"]["env"]["blank_env"] == "" @@ -419,17 +474,17 @@ def test_unquoted_template_usage_in_strings_yaml(named_temporary_file): with named_temporary_file(suffix=".yml") as p: p.write_text(dedent(text)) - project_definition = load_project([p]).project_definition + result = load_project([p]) - assert project_definition.env == ProjectEnvironment( - { + assert result.project_context.get("ctx", {}).get("env", None) == ProjectEnvironment( + default_env={ "block_multiline": "this is multiline string \nwith template Snowflake is great!\n", "flow_multiline_not_quoted": "this is multiline string with template Snowflake is great!", "flow_multiline_quoted": "this is multiline string with template Snowflake is great!", "single_line": "Snowflake is great!", "value": "Snowflake is great!", }, - {}, + override_env={}, ) @@ -444,15 +499,16 @@ def test_injected_yml_in_env_should_not_be_expanded(): result = render_definition_template(definition, {}).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - { + default_env={ "test_env": " - app/*\n - src\n", }, - {}, + override_env={}, ), - } + }, } assert result["ctx"]["env"]["var_with_yml"] == " - app/*\n - src\n" @@ -486,13 +542,10 @@ def test_invalid_type_for_env_section(): "definition_version": "1.1", "env": ["test_env", "array_val1"], } - with pytest.raises(InvalidTemplate) as err: + with pytest.raises(SchemaValidationError) as err: render_definition_template(definition, {}) - assert ( - err.value.message - == "env section in project definition file should be a mapping" - ) + assert "Input should be a valid dictionary" in err.value.message def test_invalid_type_for_env_variable(): @@ -502,13 +555,10 @@ def test_invalid_type_for_env_variable(): "test_env": ["array_val1"], }, } - with pytest.raises(InvalidTemplate) as err: + with pytest.raises(SchemaValidationError) as err: render_definition_template(definition, {}) - assert ( - err.value.message - == "Variable test_env in env section of project definition file should be a scalar" - ) + assert "Input should be a valid string" in err.value.message @mock.patch.dict(os.environ, {"env_var_test": "value_from_os_env"}, clear=True) @@ -525,16 +575,17 @@ def test_env_priority_from_cli_and_os_env_and_project_env(): ).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - { + default_env={ "env_var_test": "value_from_definition_file", "final_value": "value_from_cli_override", }, - {"env_var_test": "value_from_cli_override"}, + override_env={"env_var_test": "value_from_cli_override"}, ), - } + }, } assert result["ctx"]["env"]["env_var_test"] == "value_from_cli_override" @@ -553,13 +604,14 @@ def test_values_env_from_only_overrides(): ).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - {"final_value": "value_from_cli_override"}, - {"env_var_test": "value_from_cli_override"}, + default_env={"final_value": "value_from_cli_override"}, + override_env={"env_var_test": "value_from_cli_override"}, ), - } + }, } assert result["ctx"]["env"]["env_var_test"] == "value_from_cli_override" @@ -575,13 +627,14 @@ def test_cli_env_var_blank(): ).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - {}, - {"env_var_test": ""}, + default_env={}, + override_env={"env_var_test": ""}, ), - } + }, } assert result["ctx"]["env"]["env_var_test"] == "" @@ -597,13 +650,14 @@ def test_cli_env_var_does_not_expand_with_templating(): ).project_context assert result == { + "fn": get_templating_functions(), "ctx": { "definition_version": "1.1", "env": ProjectEnvironment( - {}, - {"env_var_test": "<% ctx.env.something %>"}, + default_env={}, + override_env={"env_var_test": "<% ctx.env.something %>"}, ), - } + }, } assert result["ctx"]["env"]["env_var_test"] == "<% ctx.env.something %>" @@ -622,11 +676,162 @@ def test_os_env_and_override_envs_in_version_1(): "ctx": { "definition_version": "1", "env": ProjectEnvironment( - {}, - {"override_env": "override_env_value"}, + default_env={}, + override_env={"override_env": "override_env_value"}, ), } } assert result["ctx"]["env"]["override_env"] == "override_env_value" assert result["ctx"]["env"]["os_env_var"] == "os_env_var_value" + + +@mock.patch.dict(os.environ, {"debug": "true", "USER": "username"}, clear=True) +def test_non_str_scalar_with_templates(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "application": {"debug": "<% ctx.env.debug %>"}, + }, + } + + result = render_definition_template(definition, {}).project_context + + assert result == { + "fn": get_templating_functions(), + "ctx": { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "bundle_root": "output/bundle/", + "deploy_root": "output/deploy/", + "generated_root": "__generated/", + "scratch_stage": "app_src.stage_snowflake_cli_scratch", + "source_stage": "app_src.stage", + "package": { + "name": "test_app_pkg_username", + "distribution": "internal", + }, + "application": { + "name": "test_app_username", + "debug": "true", + }, + }, + "env": ProjectEnvironment(default_env={}, override_env={}), + }, + } + + +@mock.patch.dict(os.environ, {"debug": "invalid boolean"}, clear=True) +def test_non_str_scalar_with_templates_with_invalid_value(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "application": {"debug": "<% ctx.env.debug %>"}, + }, + } + + with pytest.raises(SchemaValidationError) as err: + render_definition_template(definition, {}) + + assert "Input should be a valid boolean" in err.value.message + + +@mock.patch.dict(os.environ, {"stage": "app_src.stage", "USER": "username"}, clear=True) +def test_field_with_custom_validation_with_templates(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "source_stage": "<% ctx.env.stage %>", + }, + } + + result = render_definition_template(definition, {}).project_context + + assert result == { + "fn": get_templating_functions(), + "ctx": { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "bundle_root": "output/bundle/", + "deploy_root": "output/deploy/", + "generated_root": "__generated/", + "scratch_stage": "app_src.stage_snowflake_cli_scratch", + "source_stage": "app_src.stage", + "package": { + "name": "test_app_pkg_username", + "distribution": "internal", + }, + "application": { + "name": "test_app_username", + }, + }, + "env": ProjectEnvironment(default_env={}, override_env={}), + }, + } + + +@mock.patch.dict(os.environ, {"stage": "invalid stage name"}, clear=True) +def test_field_with_custom_validation_with_templates_and_invalid_value(): + definition = { + "definition_version": "1.1", + "native_app": { + "name": "test_app", + "artifacts": [], + "source_stage": "<% ctx.env.stage %>", + }, + } + + with pytest.raises(SchemaValidationError) as err: + render_definition_template(definition, {}) + + assert "Incorrect value for source_stage value of native_app" in err.value.message + + +@pytest.mark.parametrize( + "na_name, expected_app_name, expected_pkg_name", + [ + # valid unquoted ID + ("safe_name", "safe_name_username", "safe_name_pkg_username"), + # valid quoted ID with unsafe char + ('"unsafe.name"', '"unsafe.name_username"', '"unsafe.name_pkg_username"'), + # valid quoted ID with safe char + ('"safe_name"', '"safe_name_username"', '"safe_name_pkg_username"'), + # valid quoted id with double quotes char + ('"name_""_"', '"name_""__username"', '"name_""__pkg_username"'), + # unquoted ID with unsafe char + ("unsafe.name", '"unsafe.name_username"', '"unsafe.name_pkg_username"'), + ], +) +@mock.patch.dict(os.environ, {"USER": "username"}, clear=True) +def test_defaults_native_app_pkg_name( + na_name, expected_app_name: str, expected_pkg_name: str +): + + definition = { + "definition_version": "1.1", + "native_app": {"name": na_name, "artifacts": []}, + "env": { + "app_reference": "<% ctx.native_app.application.name %>", + "pkg_reference": "<% ctx.native_app.package.name %>", + }, + } + result = render_definition_template(definition, {}) + project_context = result.project_context + project_definition = result.project_definition + + assert project_definition.native_app.application.name == expected_app_name + assert project_definition.native_app.package.name == expected_pkg_name + + env = project_context.get("ctx", {}).get("env", {}) + assert env.get("app_reference") == expected_app_name + assert env.get("pkg_reference") == expected_pkg_name diff --git a/tests/api/utils/test_templating_functions.py b/tests/api/utils/test_templating_functions.py new file mode 100644 index 000000000..af575f146 --- /dev/null +++ b/tests/api/utils/test_templating_functions.py @@ -0,0 +1,360 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +import pytest +from snowflake.cli.api.exceptions import InvalidTemplate +from snowflake.cli.api.utils.definition_rendering import render_definition_template +from snowflake.cli.api.utils.templating_functions import get_templating_functions + + +def test_template_unknown_function(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.unknown_func('hello') %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "Could not find template variable fn.unknown_func" in err.value.message + + +def test_available_templating_functions(): + result = get_templating_functions() + assert sorted(result.keys()) == sorted( + [ + "id_to_str", + "str_to_id", + "id_concat", + "get_username", + "clean_id", + ] + ) + + +@pytest.mark.parametrize( + "input_list, expected_output", + [ + # test concatenate a constant with a variable -> quoted + (["'first_'", "ctx.definition_version"], '"first_1.1"'), + # test concatenate valid unquoted values -> non-quoted + (["'first_'", "'second'"], "first_second"), + # test concatenate unquoted ids with unsafe chars -> quoted + (["'first.'", "'second'"], '"first.second"'), + # all safe chars, one with quoted id -> quoted + (["'first_'", "'second_'", "'\"third\"'"], '"first_second_third"'), + # one word, unsafe chars -> quoted + (["'first.'"], '"first."'), + # one word, safe chars -> non-quoted + (["'first'"], "first"), + # blank input -> quoted blank output + (["''", "''"], '""'), + ], +) +def test_id_concat_with_valid_values(input_list, expected_output): + input_list_str = ", ".join(input_list) + definition = { + "definition_version": "1.1", + "env": { + "value": f"<% fn.id_concat({input_list_str}) %>", + }, + } + + result = render_definition_template(definition, {}).project_context + env = result.get("ctx", {}).get("env", {}) + assert env.get("value") == expected_output + + +def test_id_concat_with_no_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.id_concat() %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "id_concat requires at least 1 argument(s)" in err.value.message + + +def test_id_concat_with_non_string_arg(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.id_concat(123) %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "id_concat only accepts String values" in err.value.message + + +@pytest.mark.parametrize( + "input_val, expected_output", + [ + # unquoted safe -> unchanged + ("first", "first"), + # unquoted unsafe -> unchanged + ("first.second", "first.second"), + # looks like quoted but invalid -> unchanged + ('"first"second"', '"first"second"'), + # valid quoted -> unquoted + ('"first""second"', 'first"second'), + # unquoted blank -> blank + ("", ""), + # quoted blank -> blank + ('""', ""), + ], +) +def test_id_to_str_valid_values(input_val, expected_output): + definition = { + "definition_version": "1.1", + "env": { + "input_value": input_val, + "output_value": "<% fn.id_to_str(ctx.env.input_value) %>", + }, + } + + result = render_definition_template(definition, {}).project_context + env = result.get("ctx", {}).get("env", {}) + assert env.get("output_value") == expected_output + + +def test_id_to_str_with_no_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.id_to_str() %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "id_to_str requires at least 1 argument(s)" in err.value.message + + +def test_id_to_str_with_two_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.id_to_str('a', 'b') %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "id_to_str supports at most 1 argument(s)" in err.value.message + + +def test_id_to_str_with_non_string_arg(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.id_to_str(123) %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "id_to_str only accepts String values" in err.value.message + + +@pytest.mark.parametrize( + "input_val, expected_output", + [ + # unquoted safe -> unchanged + ("first", "first"), + # unquoted unsafe -> quoted + ("first.second", '"first.second"'), + # looks like quoted but invalid -> quote it and escape + ('"first"second"', '"""first""second"""'), + # valid quoted -> unchanged + ('"first""second"', '"first""second"'), + # blank -> quoted blank + ("", '""'), + # quoted blank -> unchanged + ('""', '""'), + ], +) +def test_str_to_id_valid_values(input_val, expected_output): + definition = { + "definition_version": "1.1", + "env": { + "input_value": input_val, + "output_value": "<% fn.str_to_id(ctx.env.input_value) %>", + }, + } + + result = render_definition_template(definition, {}).project_context + env = result.get("ctx", {}).get("env", {}) + assert env.get("output_value") == expected_output + + +def test_str_to_id_with_no_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.str_to_id() %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "str_to_id requires at least 1 argument(s)" in err.value.message + + +def test_str_to_id_with_two_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.str_to_id('a', 'b') %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "str_to_id supports at most 1 argument(s)" in err.value.message + + +def test_str_to_id_with_non_string_arg(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.str_to_id(123) %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "str_to_id only accepts String values" in err.value.message + + +@pytest.mark.parametrize( + "os_environ, expected_output", + [ + ({"USER": "test_user"}, "test_user"), + ({"USERNAME": "test_user"}, "test_user"), + ({}, ""), + ], +) +def test_get_username_valid_values(os_environ, expected_output): + definition = { + "definition_version": "1.1", + "env": { + "output_value": "<% fn.get_username() %>", + }, + } + + with mock.patch.dict(os.environ, os_environ, clear=True): + result = render_definition_template(definition, {}).project_context + + env = result.get("ctx", {}).get("env", {}) + assert env.get("output_value") == expected_output + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_get_username_with_fallback_value(): + definition = { + "definition_version": "1.1", + "env": { + "output_value": "<% fn.get_username('fallback_user') %>", + }, + } + + result = render_definition_template(definition, {}).project_context + + env = result.get("ctx", {}).get("env", {}) + assert env.get("output_value") == "fallback_user" + + +def test_get_username_with_two_args_should_fail(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.get_username('a', 'b') %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "get_username supports at most 1 argument(s)" in err.value.message + + +@pytest.mark.parametrize( + "input_value, expected_output", + [ + ("test_value", "test_value"), + (" T'EST_Va l.u-e" "", "test_value"), + ("", ""), + ('""', ""), + ('"some_id"', "some_id"), + ], +) +def test_clean_id_valid_values(input_value, expected_output): + definition = { + "definition_version": "1.1", + "env": { + "input_value": input_value, + "output_value": "<% fn.clean_id(ctx.env.input_value) %>", + }, + } + + result = render_definition_template(definition, {}).project_context + + env = result.get("ctx", {}).get("env", {}) + assert env.get("output_value") == expected_output + + +def test_clean_id_with_no_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.clean_id() %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "clean_id requires at least 1 argument(s)" in err.value.message + + +def test_clean_id_with_two_args(): + definition = { + "definition_version": "1.1", + "env": { + "value": "<% fn.clean_id('a', 'b') %>", + }, + } + + with pytest.raises(InvalidTemplate) as err: + render_definition_template(definition, {}) + + assert "clean_id supports at most 1 argument(s)" in err.value.message diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 7f2a05bcd..1a6ef1700 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -14,8 +14,9 @@ from __future__ import annotations +import os from pathlib import Path -from typing import List, Optional +from typing import List from unittest import mock from unittest.mock import PropertyMock @@ -46,6 +47,7 @@ def test_napp_project_1(project_definition_files): @pytest.mark.parametrize("project_definition_files", ["minimal"], indirect=True) +@mock.patch.dict(os.environ, {"USER": "jsmith"}) def test_na_minimal_project(project_definition_files: List[Path]): project = load_project(project_definition_files).project_definition assert project.native_app.name == "minimal" @@ -54,30 +56,23 @@ def test_na_minimal_project(project_definition_files: List[Path]): PathMapping(src="README.md"), ] - from os import getenv as original_getenv - - def mock_getenv(key: str, default: Optional[str] = None) -> Optional[str]: - if key.lower() == "user": - return "jsmith" - return original_getenv(key, default) - with mock.patch( "snowflake.cli.api.cli_global_context._CliGlobalContextAccess.connection", new_callable=PropertyMock, ) as connection: connection.return_value.role = "resolved_role" connection.return_value.warehouse = "resolved_warehouse" - with mock.patch("os.getenv", side_effect=mock_getenv): - # TODO: probably a better way of going about this is to not generate - # a definition structure for these values but directly return defaults - # in "getter" functions (higher-level data structures). - local = generate_local_override_yml(project) - assert local.native_app.application.name == "minimal_jsmith" - assert local.native_app.application.role == "resolved_role" - assert local.native_app.application.warehouse == "resolved_warehouse" - assert local.native_app.application.debug == True - assert local.native_app.package.name == "minimal_pkg_jsmith" - assert local.native_app.package.role == "resolved_role" + + # TODO: probably a better way of going about this is to not generate + # a definition structure for these values but directly return defaults + # in "getter" functions (higher-level data structures). + local = generate_local_override_yml(project) + assert local.native_app.application.name == "minimal_jsmith" + assert local.native_app.application.role == "resolved_role" + assert local.native_app.application.warehouse == "resolved_warehouse" + assert local.native_app.application.debug == True + assert local.native_app.package.name == "minimal_pkg_jsmith" + assert local.native_app.package.role == "resolved_role" @pytest.mark.parametrize("project_definition_files", ["underspecified"], indirect=True) diff --git a/tests/project/test_util.py b/tests/project/test_util.py index 6849b7998..4269f623a 100644 --- a/tests/project/test_util.py +++ b/tests/project/test_util.py @@ -17,13 +17,16 @@ import pytest from snowflake.cli.api.project.util import ( append_to_identifier, + concat_identifiers, escape_like_pattern, + identifier_to_str, is_valid_identifier, is_valid_object_name, is_valid_quoted_identifier, is_valid_string_literal, is_valid_unquoted_identifier, to_identifier, + to_quoted_identifier, to_string_literal, ) @@ -231,3 +234,78 @@ def test_to_string_literal(raw_string, literal): ) def test_escape_like_pattern(raw_string, escaped): assert escape_like_pattern(raw_string) == escaped + + +@pytest.mark.parametrize( + "input_value, expected_value", + [ + # valid unquoted id -> return quoted + ("Id_1", '"Id_1"'), + # valid quoted id without special chars -> return the same + ('"Id_1"', '"Id_1"'), + # valid quoted id with special chars -> return the same + ('"Id_""_._end"', '"Id_""_._end"'), + # unquoted with unsafe chars -> return quoted + ('Id_""_._end', '"Id_""""_._end"'), + # looks like quoted identifier but not properly escaped -> requote + ('"Id"end"', '"""Id""end"""'), + # blank -> quoted + ("", '""'), + ], +) +def test_to_quoted_identifier(input_value, expected_value): + assert to_quoted_identifier(input_value) == expected_value + + +@pytest.mark.parametrize( + "id1, id2, concatenated_value", + [ + # both unquoted, no special char -> result unquoted + ("Id_1", "_Id_2", "Id_1_Id_2"), + # both unquoted, one with special char -> result quoted + ('Id_1."', "_Id_2", '"Id_1.""_Id_2"'), + # both unquoted, one with special char -> result quoted + ("Id_1", '_Id_2."', '"Id_1_Id_2."""'), + # one quoted, no special chars -> result quoted + ('"Id_1"', "_Id_2", '"Id_1_Id_2"'), + # one quoted, no special chars -> result quoted + ("Id_1", '"_Id_2"', '"Id_1_Id_2"'), + # both quoted, no special chars -> result quoted + ('"Id_1"', '"_Id_2"', '"Id_1_Id_2"'), + # quoted with valid 2 double quotes within -> result quoted + ('"Id_""_1"', '"_""_Id_2"', '"Id_""_1_""_Id_2"'), + # quoted with invalid single double quotes within + # -> result quoted, and original quotes escaped + ('"Id_"_1"', '"_"_Id_2"', '"""Id_""_1""""_""_Id_2"""'), + # one quoted with invalid single double quotes within and other properly quoted + # -> result quoted, double quotes escaped + ('"Id_"_1"', '"_Id_2"', '"""Id_""_1""_Id_2"'), + # one quoted with escaped double quotes within + # another non quoted with double quotes within + # -> result is quoted, and proper escaping of non quoted + ('"Id_""_1"', '_Id_"_2', '"Id_""_1_Id_""_2"'), + # 2 blanks -> result should be quoted to be a valid identifier + ("", "", '""'), + ], +) +def test_concat_identifiers(id1, id2, concatenated_value): + assert concat_identifiers([id1, id2]) == concatenated_value + + +@pytest.mark.parametrize( + "identifier, expected_value", + [ + # valid unquoted id -> return same + ("Id_1", "Id_1"), + # valid quoted id without special chars -> return unquoted + ('"Id_1"', "Id_1"), + # valid quoted id with special chars -> return unquoted and unescaped " + ('"Id_""_._end"', 'Id_"_._end'), + # unquoted with unsafe chars -> return the same without unescaping + ('Id_""_._end', 'Id_""_._end'), + # blank identifier -> turns into blank string + ('""', ""), + ], +) +def test_identifier_to_str(identifier, expected_value): + assert identifier_to_str(identifier) == expected_value