diff --git a/src/snowflake/cli/api/rendering/sql_templates.py b/src/snowflake/cli/api/rendering/sql_templates.py index 01f99f5fc..86008214f 100644 --- a/src/snowflake/cli/api/rendering/sql_templates.py +++ b/src/snowflake/cli/api/rendering/sql_templates.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict +from typing import Dict, Optional from click import ClickException from jinja2 import StrictUndefined, loaders @@ -29,11 +29,11 @@ _SQL_TEMPLATE_END = "}" -def get_sql_cli_jinja_env(): +def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None): _random_block = "___very___unique___block___to___disable___logic___blocks___" return env_bootstrap( IgnoreAttrEnvironment( - loader=loaders.BaseLoader(), + loader=loader or loaders.BaseLoader(), keep_trailing_newline=True, variable_start_string=_SQL_TEMPLATE_START, variable_end_string=_SQL_TEMPLATE_END, diff --git a/src/snowflake/cli/plugins/nativeapp/exceptions.py b/src/snowflake/cli/plugins/nativeapp/exceptions.py index 8cefd4133..1db80097b 100644 --- a/src/snowflake/cli/plugins/nativeapp/exceptions.py +++ b/src/snowflake/cli/plugins/nativeapp/exceptions.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from textwrap import dedent +from typing import Optional import jinja2 from click.exceptions import ClickException @@ -54,18 +57,23 @@ def __init__(self, item: str, expected_owner: str, actual_owner: str): ) -class MissingPackageScriptError(ClickException): - """A referenced package script was not found.""" +class MissingScriptError(ClickException): + """A referenced script was not found.""" def __init__(self, relpath: str): - super().__init__(f'Package script "{relpath}" does not exist') + super().__init__(f'Script "{relpath}" does not exist') -class InvalidPackageScriptError(ClickException): - """A referenced package script had syntax error(s).""" +class InvalidScriptError(ClickException): + """A referenced script had syntax error(s).""" - def __init__(self, relpath: str, err: jinja2.TemplateError): - super().__init__(f'Package script "{relpath}" is not a valid jinja2 template') + def __init__( + self, relpath: str, err: jinja2.TemplateError, lineno: Optional[int] = None + ): + lineno_str = f":{lineno}" if lineno is not None else "" + super().__init__( + f'Script "{relpath}{lineno_str}" does not contain a valid template: {err.message}' + ) self.err = err diff --git a/src/snowflake/cli/plugins/nativeapp/manager.py b/src/snowflake/cli/plugins/nativeapp/manager.py index 69cd65cb3..af726311d 100644 --- a/src/snowflake/cli/plugins/nativeapp/manager.py +++ b/src/snowflake/cli/plugins/nativeapp/manager.py @@ -21,7 +21,7 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import List, Optional, TypedDict +from typing import Any, List, Optional, TypedDict import jinja2 from click import ClickException @@ -62,8 +62,8 @@ from snowflake.cli.plugins.nativeapp.exceptions import ( ApplicationPackageAlreadyExistsError, ApplicationPackageDoesNotExistError, - InvalidPackageScriptError, - MissingPackageScriptError, + InvalidScriptError, + MissingScriptError, SetupScriptFailedValidation, UnexpectedOwnerError, ) @@ -561,6 +561,36 @@ def create_app_package(self) -> None: ) ) + def _expand_script_templates( + self, env: jinja2.Environment, jinja_context: dict[str, Any], scripts: List[str] + ) -> List[str]: + """ + Input: + - env: Jinja2 environment + - jinja_context: a dictionary with the jinja context + - scripts: list of scripts that need to be expanded with Jinja + Returns: + - List of expanded scripts content. + Size of the return list is the same as the size of the input scripts list. + """ + scripts_contents = [] + for relpath in scripts: + try: + template = env.get_template(relpath) + result = template.render(**jinja_context) + scripts_contents.append(result) + + except jinja2.TemplateNotFound as e: + raise MissingScriptError(e.name) from e + + except jinja2.TemplateSyntaxError as e: + raise InvalidScriptError(e.name, e, e.lineno) from e + + except jinja2.UndefinedError as e: + raise InvalidScriptError(relpath, e) from e + + return scripts_contents + def _apply_package_scripts(self) -> None: """ Assuming the application package exists and we are using the correct role, @@ -572,21 +602,9 @@ def _apply_package_scripts(self) -> None: undefined=jinja2.StrictUndefined, ) - queued_queries = [] - for relpath in self.package_scripts: - try: - template = env.get_template(relpath) - result = template.render(dict(package_name=self.package_name)) - queued_queries.append(result) - - except jinja2.TemplateNotFound as e: - raise MissingPackageScriptError(e.name) - - except jinja2.TemplateSyntaxError as e: - raise InvalidPackageScriptError(e.name, e) - - except jinja2.UndefinedError as e: - raise InvalidPackageScriptError(relpath, e) + queued_queries = self._expand_script_templates( + env, dict(package_name=self.package_name), self.package_scripts + ) # once we're sure all the templates expanded correctly, execute all of them with self.use_package_warehouse(): diff --git a/src/snowflake/cli/plugins/nativeapp/run_processor.py b/src/snowflake/cli/plugins/nativeapp/run_processor.py index 92e620ca2..9a804c1dd 100644 --- a/src/snowflake/cli/plugins/nativeapp/run_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/run_processor.py @@ -18,8 +18,10 @@ from textwrap import dedent from typing import Optional +import jinja2 import typer from click import UsageError +from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.errno import ( APPLICATION_NO_LONGER_AVAILABLE, @@ -35,7 +37,9 @@ identifier_to_show_like_pattern, unquote_identifier, ) -from snowflake.cli.api.rendering.sql_templates import snowflake_sql_jinja_render +from snowflake.cli.api.rendering.sql_templates import ( + get_sql_cli_jinja_env, +) from snowflake.cli.api.utils.cursor import find_all_rows from snowflake.cli.plugins.nativeapp.artifacts import BundleMap from snowflake.cli.plugins.nativeapp.constants import ( @@ -135,20 +139,19 @@ class NativeAppRunProcessor(NativeAppManager, NativeAppCommandProcessor): def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__(project_definition, project_root) - def _execute_sql_script(self, sql_script_path): + def _execute_sql_script( + self, script_content: str, database_name: Optional[str] = None + ): """ - Executing the SQL script in the provided file path after expanding template variables. + Executing the provided SQL script content. This assumes that a relevant warehouse is already active. - Consequently, "use database" will be executed first if it is set in definition file or in the current connection. + If database_name is passed in, it will be used first. """ - with open(sql_script_path) as f: - sql_script = f.read() - try: - if self._conn.database: - self._execute_query(f"use database {self._conn.database}") - sql_script = snowflake_sql_jinja_render(content=sql_script) - self._execute_queries(sql_script) + if database_name is not None: + self._execute_query(f"use database {database_name}") + + self._execute_queries(script_content) except ProgrammingError as err: generic_sql_error_handler(err) @@ -156,15 +159,28 @@ def _execute_post_deploy_hooks(self): post_deploy_script_hooks = self.app_post_deploy_hooks if post_deploy_script_hooks: with cc.phase("Executing application post-deploy actions"): + sql_scripts_paths = [] for hook in post_deploy_script_hooks: if hook.sql_script: - cc.step(f"Executing SQL script: {hook.sql_script}") - self._execute_sql_script(hook.sql_script) + sql_scripts_paths.append(hook.sql_script) else: raise ValueError( f"Unsupported application post-deploy hook type: {hook}" ) + env = get_sql_cli_jinja_env( + loader=jinja2.loaders.FileSystemLoader(self.project_root) + ) + scripts_content_list = self._expand_script_templates( + env, cli_context.template_context, sql_scripts_paths + ) + + for index, sql_script_path in enumerate(sql_scripts_paths): + cc.step(f"Executing SQL script: {sql_script_path}") + self._execute_sql_script( + scripts_content_list[index], self._conn.database + ) + def get_all_existing_versions(self) -> SnowflakeCursor: """ Get all existing versions, if defined, for an application package. diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index 3fea55243..74935c320 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -24,8 +24,8 @@ ) from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.plugins.nativeapp.exceptions import ( - InvalidPackageScriptError, - MissingPackageScriptError, + InvalidScriptError, + MissingScriptError, ) from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor from snowflake.connector import ProgrammingError @@ -198,7 +198,7 @@ def test_package_scripts_without_conn_info_succeeds( def test_missing_package_script(mock_execute, project_definition_files): working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) - with pytest.raises(MissingPackageScriptError): + with pytest.raises(MissingScriptError): (working_dir / "002-shared.sql").unlink() native_app_manager._apply_package_scripts() # noqa: SLF001 @@ -211,7 +211,7 @@ def test_missing_package_script(mock_execute, project_definition_files): def test_invalid_package_script(mock_execute, project_definition_files): working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) - with pytest.raises(InvalidPackageScriptError): + with pytest.raises(InvalidScriptError): second_file = working_dir / "002-shared.sql" second_file.unlink() second_file.write_text("select * from {{ package_name") @@ -226,7 +226,7 @@ def test_invalid_package_script(mock_execute, project_definition_files): def test_undefined_var_package_script(mock_execute, project_definition_files): working_dir: Path = project_definition_files[0].parent native_app_manager = _get_na_manager(str(working_dir)) - with pytest.raises(InvalidPackageScriptError): + with pytest.raises(InvalidScriptError): second_file = working_dir / "001-shared.sql" second_file.unlink() second_file.write_text("select * from {{ abc }}") diff --git a/tests/nativeapp/test_post_deploy.py b/tests/nativeapp/test_post_deploy.py index c3bd9db7b..288059105 100644 --- a/tests/nativeapp/test_post_deploy.py +++ b/tests/nativeapp/test_post_deploy.py @@ -22,6 +22,7 @@ from snowflake.cli.api.project.schemas.native_app.application import ( ApplicationPostDeployHook, ) +from snowflake.cli.plugins.nativeapp.exceptions import MissingScriptError from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor from tests.nativeapp.patch_utils import mock_connection @@ -139,7 +140,7 @@ def test_missing_sql_script( with project_directory("napp_post_deploy_missing_file") as project_dir: processor = _get_run_processor(str(project_dir)) - with pytest.raises(FileNotFoundError) as err: + with pytest.raises(MissingScriptError) as err: processor._execute_post_deploy_hooks() # noqa SLF001 diff --git a/tests_integration/nativeapp/test_init_run.py b/tests_integration/nativeapp/test_init_run.py index c5b6cb0a2..203117d1b 100644 --- a/tests_integration/nativeapp/test_init_run.py +++ b/tests_integration/nativeapp/test_init_run.py @@ -26,6 +26,9 @@ not_contains_row_with, row_from_snowflake_session, ) +from tests_integration.testing_utils.working_directory_utils import ( + WorkingDirectoryChanger, +) USER_NAME = f"user_{uuid.uuid4().hex}" TEST_ENV = generate_user_env(USER_NAME) @@ -423,30 +426,37 @@ def test_nativeapp_init_from_repo_with_single_template( # Tests that application post-deploy scripts are executed by creating a post_deploy_log table and having each post-deploy script add a record to it @pytest.mark.integration @pytest.mark.parametrize("is_versioned", [True, False]) +@pytest.mark.parametrize("with_project_flag", [True, False]) def test_nativeapp_app_post_deploy( - runner, snowflake_session, project_directory, is_versioned + runner, snowflake_session, project_directory, is_versioned, with_project_flag ): version = "v1" project_name = "myapp" app_name = f"{project_name}_{USER_NAME}" - def run(): - """(maybe) create a version, then snow app run""" - if is_versioned: + with project_directory("napp_application_post_deploy") as tmp_dir: + version_run_args = ["--version", version] if is_versioned else [] + project_args = ["--project", f"{tmp_dir}"] if with_project_flag else [] + + def run(): + """(maybe) create a version, then snow app run""" + if is_versioned: + result = runner.invoke_with_connection_json( + ["app", "version", "create", version] + project_args, + env=TEST_ENV, + ) + assert result.exit_code == 0 + result = runner.invoke_with_connection_json( - ["app", "version", "create", version], + ["app", "run"] + version_run_args + project_args, env=TEST_ENV, ) assert result.exit_code == 0 - run_args = ["--version", version] if is_versioned else [] - result = runner.invoke_with_connection_json( - ["app", "run"] + run_args, - env=TEST_ENV, - ) - assert result.exit_code == 0 + if with_project_flag: + working_directory_changer = WorkingDirectoryChanger() + working_directory_changer.change_working_directory_to("app") - with project_directory("napp_application_post_deploy") as tmp_dir: try: # First run, application is created (and maybe a version) run() @@ -480,13 +490,13 @@ def run(): # need to drop the version before we can teardown if is_versioned: result = runner.invoke_with_connection_json( - ["app", "version", "drop", version, "--force"], + ["app", "version", "drop", version, "--force"] + project_args, env=TEST_ENV, ) assert result.exit_code == 0 result = runner.invoke_with_connection_json( - ["app", "teardown", "--force"], + ["app", "teardown", "--force"] + project_args, env=TEST_ENV, ) assert result.exit_code == 0