Skip to content

Commit

Permalink
Remove defaults and fix mixins behavior (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Sep 11, 2024
1 parent b586b8e commit be1929b
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 112 deletions.
3 changes: 3 additions & 0 deletions src/snowflake/cli/api/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def using_context(self) -> "FQN":
from snowflake.cli.api.cli_global_context import get_cli_context

return self.using_connection(get_cli_context().connection)

def to_dict(self) -> dict:
return {"name": self.name, "schema": self.schema, "database": self.database}
12 changes: 0 additions & 12 deletions src/snowflake/cli/api/project/schemas/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,6 @@ def ensure_use_mixins_is_a_list(
return mixins


class DefaultsField(UpdatableModel):
schema_: Optional[str] = Field(
title="Schema.",
alias="schema",
default=None,
)
stage: Optional[str] = Field(
title="Stage.",
default=None,
)


class EntityModelBase(ABC, UpdatableModel):
@classmethod
def get_type(cls) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/cli/api/project/schemas/identifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

class Identifier(UpdatableModel):
name: str = Field(title="Entity name")
schema_: str = Field(title="Entity schema", alias="schema", default=None)
database: str = Field(title="Entity database", default=None)
schema_: Optional[str] = Field(title="Entity schema", alias="schema", default=None)
database: Optional[str] = Field(title="Entity database", default=None)


class ObjectIdentifierBaseModel:
Expand Down
123 changes: 84 additions & 39 deletions src/snowflake/cli/api/project/schemas/project_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ApplicationEntityModel,
)
from snowflake.cli.api.project.schemas.entities.common import (
DefaultsField,
TargetField,
)
from snowflake.cli.api.project.schemas.entities.entities import (
Expand All @@ -42,6 +41,7 @@
from typing_extensions import Annotated

AnnotatedEntity = Annotated[EntityModel, Field(discriminator="type")]
scalar = str | int | float | bool


@dataclass
Expand Down Expand Up @@ -114,31 +114,12 @@ class DefinitionV11(DefinitionV10):
class DefinitionV20(_ProjectDefinitionBase):
entities: Dict[str, AnnotatedEntity] = Field(title="Entity definitions.")

@model_validator(mode="before")
@classmethod
def apply_defaults(cls, data: Dict) -> Dict:
"""
Applies default values that exist on the model but not specified in yml
"""
if "defaults" in data and "entities" in data:
for key, entity in data["entities"].items():
entity_fields = get_allowed_fields_for_entity(entity)
if not entity_fields:
continue
for default_key, default_value in data["defaults"].items():
if default_key in entity_fields and default_key not in entity:
entity[default_key] = default_value
return data

@field_validator("entities", mode="after")
@classmethod
def validate_entities_identifiers(
cls, entities: Dict[str, EntityModel]
) -> Dict[str, EntityModel]:
for key, entity in entities.items():
@model_validator(mode="after")
def validate_entities_identifiers(self):
for key, entity in self.entities.items():
entity.set_entity_id(key)
entity.validate_identifier()
return entities
return self

@field_validator("entities", mode="after")
@classmethod
Expand Down Expand Up @@ -179,11 +160,6 @@ def _validate_target_field(
f"Target type mismatch. Expected {target_type.__name__}, got {actual_target_type.__name__}"
)

defaults: Optional[DefaultsField] = Field(
title="Default key/value entity values that are merged recursively for each entity.",
default=None,
)

env: Optional[Dict[str, Union[str, int, bool]]] = Field(
title="Default environment specification for this project.",
default=None,
Expand All @@ -203,22 +179,83 @@ def apply_mixins(cls, data: Dict) -> Dict:
if "mixins" not in data or "entities" not in data:
return data

for entity in data["entities"].values():
entities = data["entities"]
for entity_name, entity in entities.items():
entity_mixins = entity_mixins_to_list(
entity.get("meta", {}).get("use_mixins")
)

entity_fields = get_allowed_fields_for_entity(entity)
if entity_fields and entity_mixins:
for mixin_name in entity_mixins:
if mixin_name in data["mixins"]:
for key, value in data["mixins"][mixin_name].items():
if key in entity_fields:
entity[key] = value
else:
raise ValueError(f"Mixin {mixin_name} not found in mixins")
merged_values = cls._merge_mixins_with_entity(
entity_id=entity_name,
entity=entity,
entity_mixins_names=entity_mixins,
mixin_defs=data["mixins"],
)
entities[entity_name] = merged_values
return data

@classmethod
def _merge_mixins_with_entity(
cls, entity_id: str, entity: dict, entity_mixins_names: list, mixin_defs: dict
) -> dict:
# Validate mixins
for mixin_name in entity_mixins_names:
if mixin_name not in mixin_defs:
raise ValueError(f"Mixin {mixin_name} not defined")

# Build object override data from mixins
data: dict = {}
for mx_name in entity_mixins_names:
data = cls._merge_data(data, mixin_defs[mx_name])

for key, override_value in data.items():
if key not in get_allowed_fields_for_entity(entity):
raise ValueError(
f"Unsupported key '{key}' for entity {entity_id} of type {entity['type']} "
)

entity_value = entity.get(key)
if entity_value is not None and not isinstance(
entity_value, type(override_value)
):
raise ValueError(
f"Value from mixins for property {key} is of type '{type(override_value).__name__}' "
f"while entity {entity_id} expects value of type '{type(entity_value).__name__}'"
)

# Apply entity data on top of mixins
data = cls._merge_data(data, entity)
return data

@classmethod
def _merge_data(
cls,
left: dict | list | scalar | None,
right: dict | list | scalar | None,
):
"""
Merges right data into left. Right and left is expected to be of the same type, if not right is returned.
If left is sequence then missing elements from right are appended.
If left is dictionary then we update it with data from right. The update is done recursively key by key.
"""
if left is None:
return right

# At that point left and right are of the same type
if isinstance(left, dict) and isinstance(right, dict):
data = dict(left)
for key in right:
data[key] = cls._merge_data(left=data.get(key), right=right[key])
return data

if isinstance(left, list) and isinstance(right, list):
return _unique_extend(left, right)

if not isinstance(right, type(left)):
raise ValueError(f"Could not merge {type(right)} and {type(left)}.")

return right

def get_entities_by_type(self, entity_type: str):
return {i: e for i, e in self.entities.items() if e.get_type() == entity_type}

Expand Down Expand Up @@ -268,3 +305,11 @@ def get_allowed_fields_for_entity(entity: Dict[str, Any]) -> List[str]:

entity_model = v2_entity_model_types_map[entity_type]
return entity_model.model_fields


def _unique_extend(list_a: List, list_b: List) -> List:
new_list = list(list_a)
for item in list_b:
if item not in list_a:
new_list.append(item)
return new_list
Loading

0 comments on commit be1929b

Please sign in to comment.