Skip to content

Commit

Permalink
Add support for defaults in project definition file templating (#1330)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-melnacouzi committed Jul 31, 2024
1 parent 6f1db14 commit 06a9ec3
Show file tree
Hide file tree
Showing 23 changed files with 1,460 additions and 187 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ but should be replaced with `snow init`
* Added connection option `--token-file-path` allowing passing OAuth token using a file. The function is also
supported by setting `token_file_path` in connection definition.
* Support for Python remote execution via `snow stage execute` and `snow git execute` similar to existing EXECUTE IMMEDIATE support.
* Added support for project definition file defaults in templates
* Added support for autocomplete in `--connection` flag.
* Added `snow init` command, which supports initializing projects with external templates.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Literal, Optional

from pydantic import AliasChoices, Field
from pydantic import Field
from snowflake.cli.api.project.schemas.entities.application_package_entity import (
ApplicationPackageEntity,
)
Expand All @@ -25,26 +25,20 @@
TargetField,
)
from snowflake.cli.api.project.schemas.updatable_model import (
UpdatableModel,
DiscriminatorField,
)


class ApplicationEntity(EntityBase):
type: Literal["application"] # noqa: A003
type: Literal["application"] = DiscriminatorField() # noqa A003
name: str = Field(
title="Name of the application created when this entity is deployed"
)
from_: ApplicationFromField = Field(
validation_alias=AliasChoices("from"),
from_: TargetField[ApplicationPackageEntity] = Field(
alias="from",
title="An application package this entity should be created from",
)
debug: Optional[bool] = Field(
title="Whether to enable debug mode when using a named stage to create an application object",
default=None,
)


class ApplicationFromField(UpdatableModel):
target: TargetField[ApplicationPackageEntity] = Field(
title="Reference to an application package entity",
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
)
from snowflake.cli.api.project.schemas.native_app.package import DistributionOptions
from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
from snowflake.cli.api.project.schemas.updatable_model import IdentifierField
from snowflake.cli.api.project.schemas.updatable_model import (
DiscriminatorField,
IdentifierField,
)


class ApplicationPackageEntity(EntityBase):
type: Literal["application package"] # noqa: A003
type: Literal["application package"] = DiscriminatorField() # noqa: A003
name: str = Field(
title="Name of the application package created when this entity is deployed"
)
Expand Down
33 changes: 13 additions & 20 deletions src/snowflake/cli/api/project/schemas/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from abc import ABC
from typing import Generic, List, Optional, TypeVar

from pydantic import AliasChoices, Field, GetCoreSchemaHandler, ValidationInfo
from pydantic_core import core_schema
from pydantic import Field
from snowflake.cli.api.project.schemas.native_app.application import (
ApplicationPostDeployHook,
)
Expand All @@ -45,7 +44,7 @@ class MetaField(UpdatableModel):
class DefaultsField(UpdatableModel):
schema_: Optional[str] = Field(
title="Schema.",
validation_alias=AliasChoices("schema"),
alias="schema",
default=None,
)
stage: Optional[str] = Field(
Expand All @@ -65,21 +64,15 @@ def get_type(cls) -> str:
TargetType = TypeVar("TargetType")


class TargetField(Generic[TargetType]):
def __init__(self, entity_target_key: str):
self.value = entity_target_key

def __repr__(self):
return self.value

@classmethod
def validate(cls, value: str, info: ValidationInfo) -> TargetField:
return cls(value)
class TargetField(UpdatableModel, Generic[TargetType]):
target: str = Field(
title="Reference to a target entity",
)

@classmethod
def __get_pydantic_core_schema__(
cls, source_type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.with_info_after_validator_function(
cls.validate, handler(str), field_name=handler.field_name
)
def get_type(self) -> type:
"""
Returns the generic type of this class, indicating the entity type.
Pydantic extracts Generic annotations, and populates
them in __pydantic_generic_metadata__
"""
return self.__pydantic_generic_metadata__["args"][0]
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ class Application(UpdatableModel):
title="Actions that will be executed after the application object is created/upgraded",
default=None,
)


class ApplicationV11(Application):
# Templated defaults only supported in v1.1+
name: Optional[str] = Field(
title="Name of the application object created when you run the snow app run command",
default="<% fn.id_concat(ctx.native_app.name, '_', fn.clean_id(fn.get_username('unknown_user'))) %>",
)
15 changes: 13 additions & 2 deletions src/snowflake/cli/api/project/schemas/native_app/native_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
from typing import List, Optional, Union

from pydantic import Field, field_validator
from snowflake.cli.api.project.schemas.native_app.application import Application
from snowflake.cli.api.project.schemas.native_app.package import Package
from snowflake.cli.api.project.schemas.native_app.application import (
Application,
ApplicationV11,
)
from snowflake.cli.api.project.schemas.native_app.package import Package, PackageV11
from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping
from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
from snowflake.cli.api.project.util import (
Expand Down Expand Up @@ -80,3 +83,11 @@ def transform_artifacts(
transformed_artifacts.append(PathMapping(src=artifact))

return transformed_artifacts


class NativeAppV11(NativeApp):
# templated defaults are only supported with version 1.1+
package: Optional[PackageV11] = Field(title="PackageSchema", default=PackageV11())
application: Optional[ApplicationV11] = Field(
title="Application info", default=ApplicationV11()
)
8 changes: 8 additions & 0 deletions src/snowflake/cli/api/project/schemas/native_app/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@ def validate_scripts(cls, input_list):
"package.scripts field should contain unique values. Check the list for duplicates and try again"
)
return input_list


class PackageV11(Package):
# Templated defaults only supported in v1.1+
name: Optional[str] = IdentifierField(
title="Name of the application package created when you run the snow app run command",
default="<% fn.id_concat(ctx.native_app.name, '_pkg_', fn.clean_id(fn.get_username('unknown_user'))) %>",
)
63 changes: 23 additions & 40 deletions src/snowflake/cli/api/project/schemas/project_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
Entity,
v2_entity_types_map,
)
from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp
from snowflake.cli.api.project.schemas.native_app.native_app import (
NativeApp,
NativeAppV11,
)
from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark
from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel
from snowflake.cli.api.utils.models import ProjectEnvironment
from snowflake.cli.api.utils.types import Context
from typing_extensions import Annotated

Expand Down Expand Up @@ -99,22 +101,14 @@ class DefinitionV10(_ProjectDefinitionBase):


class DefinitionV11(DefinitionV10):
env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
title="Environment specification for this project.",
native_app: Optional[NativeAppV11] = Field(
title="Native app definitions for the project", default=None
)
env: Optional[Dict[str, Union[str, int, bool]]] = Field(
title="Default environment specification for this project.",
default=None,
validation_alias="env",
union_mode="smart",
)

@field_validator("env")
@classmethod
def _convert_env(
cls, env: Union[Dict, ProjectEnvironment, None]
) -> ProjectEnvironment:
if isinstance(env, ProjectEnvironment):
return env
return ProjectEnvironment(default_env=(env or {}), override_env={})


class DefinitionV20(_ProjectDefinitionBase):
entities: Dict[str, Annotated[Entity, Field(discriminator="type")]] = Field(
Expand Down Expand Up @@ -147,10 +141,10 @@ def validate_entities(cls, entities: Dict[str, Entity]) -> Dict[str, Entity]:
for key, entity in entities.items():
# TODO Automatically detect TargetFields to validate
if entity.type == ApplicationEntity.get_type():
if isinstance(entity.from_.target, TargetField):
target_key = str(entity.from_.target)
target_class = entity.from_.__class__.model_fields["target"]
target_type = target_class.annotation.__args__[0]
if isinstance(entity.from_, TargetField):
target_key = entity.from_.target
target_object = entity.from_
target_type = target_object.get_type()
cls._validate_target_field(target_key, target_type, entities)
return entities

Expand All @@ -160,37 +154,26 @@ def _validate_target_field(
):
if target_key not in entities:
raise ValueError(f"No such target: {target_key}")
else:
# Validate the target type
actual_target_type = entities[target_key].__class__
if target_type and target_type is not actual_target_type:
raise ValueError(
f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
)

# Validate the target type
actual_target_type = entities[target_key].__class__
if target_type and target_type is not actual_target_type:
raise ValueError(
f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
)

defaults: Optional[DefaultsField] = Field(
title="Default key/value entity values that are merged recursively for each entity.",
default=None,
)

env: Union[Dict[str, str], ProjectEnvironment, None] = Field(
title="Environment specification for this project.",
env: Optional[Dict[str, Union[str, int, bool]]] = Field(
title="Default environment specification for this project.",
default=None,
validation_alias="env",
union_mode="smart",
)

@field_validator("env")
@classmethod
def _convert_env(
cls, env: Union[Dict, ProjectEnvironment, None]
) -> ProjectEnvironment:
if isinstance(env, ProjectEnvironment):
return env
return ProjectEnvironment(default_env=(env or {}), override_env={})


def build_project_definition(**data):
def build_project_definition(**data) -> ProjectDefinition:
"""
Returns a ProjectDefinition instance with a version matching the provided definition_version value
"""
Expand Down
4 changes: 1 addition & 3 deletions src/snowflake/cli/api/project/schemas/snowpark/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from pydantic import Field, field_validator
from snowflake.cli.api.project.schemas.identifier_model import ObjectIdentifierModel
from snowflake.cli.api.project.schemas.snowpark.argument import Argument
from snowflake.cli.api.project.schemas.updatable_model import (
UpdatableModel,
)
from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel


class _CallableBase(UpdatableModel):
Expand Down
Loading

0 comments on commit 06a9ec3

Please sign in to comment.