Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post_deploy script should run relative to project root, Fixes #1325. #1340

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/snowflake/cli/api/rendering/sql_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Dict
from typing import Dict, Optional

from click import ClickException
from jinja2 import StrictUndefined, loaders
Expand All @@ -29,11 +29,11 @@
_SQL_TEMPLATE_END = "}"


def get_sql_cli_jinja_env():
def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None):
_random_block = "___very___unique___block___to___disable___logic___blocks___"
return env_bootstrap(
IgnoreAttrEnvironment(
loader=loaders.BaseLoader(),
loader=loader or loaders.BaseLoader(),
keep_trailing_newline=True,
variable_start_string=_SQL_TEMPLATE_START,
variable_end_string=_SQL_TEMPLATE_END,
Expand Down
22 changes: 15 additions & 7 deletions src/snowflake/cli/plugins/nativeapp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from textwrap import dedent
from typing import Optional

import jinja2
from click.exceptions import ClickException
Expand Down Expand Up @@ -54,18 +57,23 @@ def __init__(self, item: str, expected_owner: str, actual_owner: str):
)


class MissingPackageScriptError(ClickException):
"""A referenced package script was not found."""
class MissingScriptError(ClickException):
"""A referenced script was not found."""

def __init__(self, relpath: str):
super().__init__(f'Package script "{relpath}" does not exist')
super().__init__(f'Script "{relpath}" does not exist')


class InvalidPackageScriptError(ClickException):
"""A referenced package script had syntax error(s)."""
class InvalidScriptError(ClickException):
"""A referenced script had syntax error(s)."""

def __init__(self, relpath: str, err: jinja2.TemplateError):
super().__init__(f'Package script "{relpath}" is not a valid jinja2 template')
def __init__(
self, relpath: str, err: jinja2.TemplateError, lineno: Optional[int] = None
):
lineno_str = f":{lineno}" if lineno is not None else ""
super().__init__(
f'Script "{relpath}{lineno_str}" does not contain a valid template: {err.message}'
)
self.err = err
sfc-gh-cgorrie marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
54 changes: 36 additions & 18 deletions src/snowflake/cli/plugins/nativeapp/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functools import cached_property
from pathlib import Path
from textwrap import dedent
from typing import List, Optional, TypedDict
from typing import Any, List, Optional, TypedDict

import jinja2
from click import ClickException
Expand Down Expand Up @@ -62,8 +62,8 @@
from snowflake.cli.plugins.nativeapp.exceptions import (
ApplicationPackageAlreadyExistsError,
ApplicationPackageDoesNotExistError,
InvalidPackageScriptError,
MissingPackageScriptError,
InvalidScriptError,
MissingScriptError,
SetupScriptFailedValidation,
UnexpectedOwnerError,
)
Expand Down Expand Up @@ -561,6 +561,36 @@ def create_app_package(self) -> None:
)
)

def _expand_script_templates(
self, env: jinja2.Environment, jinja_context: dict[str, Any], scripts: List[str]
) -> List[str]:
"""
Input:
- env: Jinja2 environment
- jinja_context: a dictionary with the jinja context
- scripts: list of scripts that need to be expanded with Jinja
Returns:
- List of expanded scripts content.
Size of the return list is the same as the size of the input scripts list.
"""
scripts_contents = []
for relpath in scripts:
try:
template = env.get_template(relpath)
result = template.render(**jinja_context)
scripts_contents.append(result)

except jinja2.TemplateNotFound as e:
raise MissingScriptError(e.name) from e

except jinja2.TemplateSyntaxError as e:
raise InvalidScriptError(e.name, e, e.lineno) from e

except jinja2.UndefinedError as e:
raise InvalidScriptError(relpath, e) from e

return scripts_contents

def _apply_package_scripts(self) -> None:
"""
Assuming the application package exists and we are using the correct role,
Expand All @@ -572,21 +602,9 @@ def _apply_package_scripts(self) -> None:
undefined=jinja2.StrictUndefined,
)

queued_queries = []
for relpath in self.package_scripts:
try:
template = env.get_template(relpath)
result = template.render(dict(package_name=self.package_name))
queued_queries.append(result)

except jinja2.TemplateNotFound as e:
raise MissingPackageScriptError(e.name)

except jinja2.TemplateSyntaxError as e:
raise InvalidPackageScriptError(e.name, e)

except jinja2.UndefinedError as e:
raise InvalidPackageScriptError(relpath, e)
queued_queries = self._expand_script_templates(
env, dict(package_name=self.package_name), self.package_scripts
)

# once we're sure all the templates expanded correctly, execute all of them
with self.use_package_warehouse():
Expand Down
42 changes: 29 additions & 13 deletions src/snowflake/cli/plugins/nativeapp/run_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from textwrap import dedent
from typing import Optional

import jinja2
import typer
from click import UsageError
from snowflake.cli.api.cli_global_context import cli_context
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.errno import (
APPLICATION_NO_LONGER_AVAILABLE,
Expand All @@ -35,7 +37,9 @@
identifier_to_show_like_pattern,
unquote_identifier,
)
from snowflake.cli.api.rendering.sql_templates import snowflake_sql_jinja_render
from snowflake.cli.api.rendering.sql_templates import (
get_sql_cli_jinja_env,
)
from snowflake.cli.api.utils.cursor import find_all_rows
from snowflake.cli.plugins.nativeapp.artifacts import BundleMap
from snowflake.cli.plugins.nativeapp.constants import (
Expand Down Expand Up @@ -135,36 +139,48 @@ class NativeAppRunProcessor(NativeAppManager, NativeAppCommandProcessor):
def __init__(self, project_definition: NativeApp, project_root: Path):
super().__init__(project_definition, project_root)

def _execute_sql_script(self, sql_script_path):
def _execute_sql_script(
self, script_content: str, database_name: Optional[str] = None
):
"""
Executing the SQL script in the provided file path after expanding template variables.
Executing the provided SQL script content.
This assumes that a relevant warehouse is already active.
Consequently, "use database" will be executed first if it is set in definition file or in the current connection.
If database_name is passed in, it will be used first.
"""
with open(sql_script_path) as f:
sql_script = f.read()

try:
if self._conn.database:
self._execute_query(f"use database {self._conn.database}")
sql_script = snowflake_sql_jinja_render(content=sql_script)
self._execute_queries(sql_script)
if database_name is not None:
self._execute_query(f"use database {database_name}")

self._execute_queries(script_content)
except ProgrammingError as err:
generic_sql_error_handler(err)

def _execute_post_deploy_hooks(self):
post_deploy_script_hooks = self.app_post_deploy_hooks
if post_deploy_script_hooks:
with cc.phase("Executing application post-deploy actions"):
sql_scripts_paths = []
for hook in post_deploy_script_hooks:
if hook.sql_script:
cc.step(f"Executing SQL script: {hook.sql_script}")
self._execute_sql_script(hook.sql_script)
sql_scripts_paths.append(hook.sql_script)
else:
raise ValueError(
f"Unsupported application post-deploy hook type: {hook}"
)

env = get_sql_cli_jinja_env(
loader=jinja2.loaders.FileSystemLoader(self.project_root)
)
scripts_content_list = self._expand_script_templates(
env, cli_context.template_context, sql_scripts_paths
)

for index, sql_script_path in enumerate(sql_scripts_paths):
cc.step(f"Executing SQL script: {sql_script_path}")
self._execute_sql_script(
scripts_content_list[index], self._conn.database
)

def get_all_existing_versions(self) -> SnowflakeCursor:
"""
Get all existing versions, if defined, for an application package.
Expand Down
10 changes: 5 additions & 5 deletions tests/nativeapp/test_package_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
)
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.plugins.nativeapp.exceptions import (
InvalidPackageScriptError,
MissingPackageScriptError,
InvalidScriptError,
MissingScriptError,
)
from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor
from snowflake.connector import ProgrammingError
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_package_scripts_without_conn_info_succeeds(
def test_missing_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(MissingPackageScriptError):
with pytest.raises(MissingScriptError):
(working_dir / "002-shared.sql").unlink()
native_app_manager._apply_package_scripts() # noqa: SLF001

Expand All @@ -211,7 +211,7 @@ def test_missing_package_script(mock_execute, project_definition_files):
def test_invalid_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(InvalidPackageScriptError):
with pytest.raises(InvalidScriptError):
second_file = working_dir / "002-shared.sql"
second_file.unlink()
second_file.write_text("select * from {{ package_name")
Expand All @@ -226,7 +226,7 @@ def test_invalid_package_script(mock_execute, project_definition_files):
def test_undefined_var_package_script(mock_execute, project_definition_files):
working_dir: Path = project_definition_files[0].parent
native_app_manager = _get_na_manager(str(working_dir))
with pytest.raises(InvalidPackageScriptError):
with pytest.raises(InvalidScriptError):
second_file = working_dir / "001-shared.sql"
second_file.unlink()
second_file.write_text("select * from {{ abc }}")
Expand Down
3 changes: 2 additions & 1 deletion tests/nativeapp/test_post_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from snowflake.cli.api.project.schemas.native_app.application import (
ApplicationPostDeployHook,
)
from snowflake.cli.plugins.nativeapp.exceptions import MissingScriptError
from snowflake.cli.plugins.nativeapp.run_processor import NativeAppRunProcessor

from tests.nativeapp.patch_utils import mock_connection
Expand Down Expand Up @@ -139,7 +140,7 @@ def test_missing_sql_script(
with project_directory("napp_post_deploy_missing_file") as project_dir:
processor = _get_run_processor(str(project_dir))

with pytest.raises(FileNotFoundError) as err:
with pytest.raises(MissingScriptError) as err:
processor._execute_post_deploy_hooks() # noqa SLF001


Expand Down
38 changes: 24 additions & 14 deletions tests_integration/nativeapp/test_init_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
not_contains_row_with,
row_from_snowflake_session,
)
from tests_integration.testing_utils.working_directory_utils import (
WorkingDirectoryChanger,
)

USER_NAME = f"user_{uuid.uuid4().hex}"
TEST_ENV = generate_user_env(USER_NAME)
Expand Down Expand Up @@ -423,30 +426,37 @@ def test_nativeapp_init_from_repo_with_single_template(
# Tests that application post-deploy scripts are executed by creating a post_deploy_log table and having each post-deploy script add a record to it
@pytest.mark.integration
@pytest.mark.parametrize("is_versioned", [True, False])
@pytest.mark.parametrize("with_project_flag", [True, False])
def test_nativeapp_app_post_deploy(
runner, snowflake_session, project_directory, is_versioned
runner, snowflake_session, project_directory, is_versioned, with_project_flag
):
version = "v1"
project_name = "myapp"
app_name = f"{project_name}_{USER_NAME}"

def run():
"""(maybe) create a version, then snow app run"""
if is_versioned:
with project_directory("napp_application_post_deploy") as tmp_dir:
version_run_args = ["--version", version] if is_versioned else []
project_args = ["--project", f"{tmp_dir}"] if with_project_flag else []

def run():
"""(maybe) create a version, then snow app run"""
if is_versioned:
result = runner.invoke_with_connection_json(
["app", "version", "create", version] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

result = runner.invoke_with_connection_json(
["app", "version", "create", version],
["app", "run"] + version_run_args + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

run_args = ["--version", version] if is_versioned else []
result = runner.invoke_with_connection_json(
["app", "run"] + run_args,
env=TEST_ENV,
)
assert result.exit_code == 0
if with_project_flag:
working_directory_changer = WorkingDirectoryChanger()
working_directory_changer.change_working_directory_to("app")

with project_directory("napp_application_post_deploy") as tmp_dir:
try:
# First run, application is created (and maybe a version)
run()
Expand Down Expand Up @@ -480,13 +490,13 @@ def run():
# need to drop the version before we can teardown
if is_versioned:
result = runner.invoke_with_connection_json(
["app", "version", "drop", version, "--force"],
["app", "version", "drop", version, "--force"] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0

result = runner.invoke_with_connection_json(
["app", "teardown", "--force"],
["app", "teardown", "--force"] + project_args,
env=TEST_ENV,
)
assert result.exit_code == 0
Expand Down
Loading