Skip to content

Commit

Permalink
Improve messaging for templates processor (#1521)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-melnacouzi committed Sep 9, 2024
1 parent ee0e00d commit d3de8bb
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,8 @@ def collect_extension_functions(
predicate=is_python_file_artifact,
)
):
cc.step(
"Processing Snowpark annotations from {}".format(
dest_file.relative_to(bundle_map.deploy_root())
)
)
src_file_name = src_file.relative_to(self._bundle_ctx.project_root)
cc.step(f"Processing Snowpark annotations from {src_file_name}")
collected_extension_function_json = _execute_in_sandbox(
py_file=str(dest_file.resolve()),
deploy_root=self._bundle_ctx.deploy_root,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from pathlib import Path
from typing import Optional

import jinja2
Expand All @@ -30,27 +31,72 @@
)
from snowflake.cli.api.rendering.project_definition_templates import (
get_client_side_jinja_env,
has_client_side_templates,
)
from snowflake.cli.api.rendering.sql_templates import (
choose_sql_jinja_env_based_on_template_syntax,
has_sql_templates,
)


def _is_sql_file(file: Path) -> bool:
return file.name.lower().endswith(".sql")


class TemplatesProcessor(ArtifactProcessor):
"""
Processor class to perform template expansion on all relevant artifacts (specified in the project definition file).
"""

def expand_templates_in_file(self, src: Path, dest: Path) -> None:
"""
Expand templates in the file.
"""
if src.is_dir():
return

with self.edit_file(dest) as file:
if not has_client_side_templates(file.contents) and not (
_is_sql_file(dest) and has_sql_templates(file.contents)
):
return

src_file_name = src.relative_to(self._bundle_ctx.project_root)
cc.step(f"Expanding templates in {src_file_name}")
with cc.indented():
try:
jinja_env = (
choose_sql_jinja_env_based_on_template_syntax(
file.contents, reference_name=src_file_name
)
if _is_sql_file(dest)
else get_client_side_jinja_env()
)
expanded_template = jinja_env.from_string(file.contents).render(
get_cli_context().template_context
)

# For now, we are printing the source file path in the error message
# instead of the destination file path to make it easier for the user
# to identify the file that has the error, and edit the correct file.
except jinja2.TemplateSyntaxError as e:
raise InvalidTemplateInFileError(src_file_name, e, e.lineno) from e

except jinja2.UndefinedError as e:
raise InvalidTemplateInFileError(src_file_name, e) from e

if expanded_template != file.contents:
file.edited_contents = expanded_template

def process(
self,
artifact_to_process: PathMapping,
processor_mapping: Optional[ProcessorMapping],
**kwargs,
):
) -> None:
"""
Process the artifact by executing the template expansion logic on it.
"""
cc.step(f"Processing artifact {artifact_to_process} with templates processor")

bundle_map = BundleMap(
project_root=self._bundle_ctx.project_root,
Expand All @@ -62,32 +108,4 @@ def process(
absolute=True,
expand_directories=True,
):
if src.is_dir():
continue
with self.edit_file(dest) as f:
file_name = src.relative_to(self._bundle_ctx.project_root)

jinja_env = (
choose_sql_jinja_env_based_on_template_syntax(
f.contents, reference_name=file_name
)
if dest.name.lower().endswith(".sql")
else get_client_side_jinja_env()
)

try:
expanded_template = jinja_env.from_string(f.contents).render(
get_cli_context().template_context
)

# For now, we are printing the source file path in the error message
# instead of the destination file path to make it easier for the user
# to identify the file that has the error, and edit the correct file.
except jinja2.TemplateSyntaxError as e:
raise InvalidTemplateInFileError(file_name, e, e.lineno) from e

except jinja2.UndefinedError as e:
raise InvalidTemplateInFileError(file_name, e) from e

if expanded_template != f.contents:
f.edited_contents = expanded_template
self.expand_templates_in_file(src, dest)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
_YML_TEMPLATE_END = "%>"


def has_client_side_templates(template_content: str) -> bool:
return _YML_TEMPLATE_START in template_content


def get_client_side_jinja_env() -> Environment:
_random_block = "___very___unique___block___to___disable___logic___blocks___"
return env_bootstrap(
Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/cli/api/rendering/sql_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def _does_template_have_env_syntax(env: Environment, template_content: str) -> b
return bool(meta.find_undeclared_variables(template))


def has_sql_templates(template_content: str) -> bool:
return (
_OLD_SQL_TEMPLATE_START in template_content
or _SQL_TEMPLATE_START in template_content
)


def choose_sql_jinja_env_based_on_template_syntax(
template_content: str, reference_name: Optional[str] = None
) -> Environment:
Expand Down
30 changes: 29 additions & 1 deletion tests/api/test_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
import pytest
from click import ClickException
from jinja2 import UndefinedError
from snowflake.cli.api.rendering.sql_templates import snowflake_sql_jinja_render
from snowflake.cli.api.rendering.project_definition_templates import (
has_client_side_templates,
)
from snowflake.cli.api.rendering.sql_templates import (
has_sql_templates,
snowflake_sql_jinja_render,
)
from snowflake.cli.api.utils.models import ProjectEnvironment


Expand Down Expand Up @@ -118,3 +124,25 @@ def test_contex_can_access_environment_variable(cli_context):
assert snowflake_sql_jinja_render("&{ ctx.env.TEST_ENV_VAR }") == os.environ.get(
"TEST_ENV_VAR"
)


def test_has_sql_templates():
assert has_sql_templates("abc <% %> abc")
assert has_sql_templates("abc <% abc")
assert has_sql_templates("abc &{ foo } abc")
assert has_sql_templates("abc &{ abc")
assert not has_sql_templates("SELECT 1")
assert not has_sql_templates("<test>")
assert not has_sql_templates("{<est}")
assert not has_sql_templates("")


def test_has_client_side_templates():
assert has_client_side_templates("abc <% %> abc")
assert has_client_side_templates("abc <% abc")
assert not has_client_side_templates("abc &{ foo } abc")
assert not has_client_side_templates("abc &{ abc")
assert not has_client_side_templates("SELECT 1")
assert not has_client_side_templates("<test>")
assert not has_client_side_templates("{<est}")
assert not has_client_side_templates("")

0 comments on commit d3de8bb

Please sign in to comment.