Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support project definition V2 in streamlit deploy command (#1369) #1394

Merged
merged 3 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from snowflake.cli.api.project.schemas.entities.application_package_entity import (
ApplicationPackageEntity,
)
from snowflake.cli.api.project.schemas.entities.streamlit_entity import StreamlitEntity

Entity = Union[ApplicationEntity, ApplicationPackageEntity]
Entity = Union[ApplicationEntity, ApplicationPackageEntity, StreamlitEntity]

ALL_ENTITIES = [*get_args(Entity)]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2024 Snowflake Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from pathlib import Path
from typing import List, Literal, Optional

from pydantic import Field, model_validator
from snowflake.cli.api.project.schemas.entities.common import EntityBase
from snowflake.cli.api.project.schemas.identifier_model import ObjectIdentifierModel
from snowflake.cli.api.project.schemas.updatable_model import (
DiscriminatorField,
)


class StreamlitEntity(EntityBase, ObjectIdentifierModel(object_name="Streamlit")):  # type: ignore
    """Project definition entity describing a single Streamlit app.

    The entity is selected by the ``type: streamlit`` discriminator. Two
    validators run after field validation whenever ``artifacts`` is a
    non-empty list:

    * the main file must be one of the listed artifacts;
    * every listed artifact must exist on the local filesystem.
    """

    type: Literal["streamlit"] = DiscriminatorField()  # noqa: A003
    title: Optional[str] = Field(
        title="Human-readable title for the Streamlit dashboard", default=None
    )
    # The default is None, so the annotation has to be Optional; with a plain
    # ``str`` an explicit ``query_warehouse: null`` would be rejected even
    # though omitting the key yields the very same value.
    query_warehouse: Optional[str] = Field(
        title="Snowflake warehouse to host the app", default=None
    )
    main_file: Optional[str] = Field(
        title="Entrypoint file of the Streamlit app", default="streamlit_app.py"
    )
    pages_dir: Optional[str] = Field(title="Streamlit pages", default=None)
    stage: Optional[str] = Field(
        title="Stage in which the app’s artifacts will be stored", default="streamlit"
    )
    # Possibly can be PathMapping
    artifacts: Optional[List[Path]] = Field(
        title="List of files which should be deployed. Each file needs to exist locally. "
        "Main file needs to be included in the artifacts.",
        default=None,
    )

    @model_validator(mode="after")
    def main_file_must_be_in_artifacts(self):
        """Ensure the main file is listed among the artifacts.

        Skipped when no artifacts are given.

        Raises:
            ValueError: if ``main_file`` is unset or not present in ``artifacts``.
        """
        if not self.artifacts:
            return self

        # ``main_file`` is Optional: Path(None) would raise TypeError, which
        # pydantic does not report as a validation error. An unset main file
        # can never be part of the artifacts, so fail with the regular message.
        if self.main_file is None or Path(self.main_file) not in self.artifacts:
            raise ValueError(
                f"Specified main file {self.main_file} is not included in artifacts."
            )
        return self

    @model_validator(mode="after")
    def artifacts_must_exists(self):
        """Ensure every listed artifact exists locally.

        Skipped when no artifacts are given. Paths are checked with
        ``Path.exists()``, so relative paths resolve against the current
        working directory.

        Raises:
            ValueError: for the first artifact that does not exist.
        """
        if not self.artifacts:
            return self

        for artifact in self.artifacts:
            if not artifact.exists():
                raise ValueError(
                    f"Specified artifact {artifact} does not exist locally."
                )

        return self
3 changes: 3 additions & 0 deletions src/snowflake/cli/api/project/schemas/project_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def _validate_target_field(
default=None,
)

def get_entities_by_type(self, entity_type: str):
    """Return the entities whose ``get_type()`` equals ``entity_type``.

    The result is a new dict keyed like ``self.entities`` and preserving its
    iteration order; it is empty when nothing matches.
    """
    matching = {}
    for entity_id, entity in self.entities.items():
        if entity.get_type() == entity_type:
            matching[entity_id] = entity
    return matching


def build_project_definition(**data) -> ProjectDefinition:
"""
Expand Down
101 changes: 79 additions & 22 deletions src/snowflake/cli/plugins/streamlit/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from pathlib import Path
from typing import Dict

import click
import typer
Expand All @@ -28,14 +29,18 @@
from snowflake.cli.api.commands.flags import ReplaceOption, like_option
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.exceptions import NoProjectDefinitionError
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.output.types import (
CommandResult,
MessageResult,
SingleQueryResult,
)
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit
from snowflake.cli.api.project.schemas.entities.streamlit_entity import StreamlitEntity
from snowflake.cli.api.project.schemas.project_definition import (
ProjectDefinition,
ProjectDefinitionV2,
)
from snowflake.cli.plugins.object.command_aliases import (
add_object_command_aliases,
scope_option,
Expand Down Expand Up @@ -129,37 +134,42 @@ def streamlit_deploy(
stage is used. If the specified stage does not exist, the command creates it.
"""

assert_project_type("streamlit")
cli_context = get_cli_context()
pd = cli_context.project_definition
if not pd.meets_version_requirement("2"):
if not pd.streamlit:
raise NoProjectDefinitionError(
project_type="streamlit", project_file=cli_context.project_root
)
pd = _migrate_v1_streamlit_to_v2(pd)

streamlits: Dict[str, StreamlitEntity] = pd.get_entities_by_type(
entity_type="streamlit"
)

streamlit: Streamlit = get_cli_context().project_definition.streamlit
if not streamlit:
return MessageResult("No streamlit were specified in project definition.")
if not streamlits:
raise NoProjectDefinitionError(
project_type="streamlit", project_file=cli_context.project_root
)

environment_file = streamlit.env_file
if environment_file and not Path(environment_file).exists():
raise ClickException(f"Provided file {environment_file} does not exist")
elif environment_file is None:
environment_file = "environment.yml"

pages_dir = streamlit.pages_dir
if pages_dir and not Path(pages_dir).exists():
raise ClickException(f"Provided file {pages_dir} does not exist")
elif pages_dir is None:
pages_dir = "pages"
# TODO: support deploying multiple streamlit entities per project in a follow-up
if len(list(streamlits)) > 1:
raise ClickException(
"Currently only single streamlit entity per project is supported."
)

# Get first streamlit
streamlit: StreamlitEntity = streamlits[list(streamlits)[0]]
streamlit_id = FQN.from_identifier_model(streamlit).using_context()

url = StreamlitManager().deploy(
streamlit_id=streamlit_id,
environment_file=Path(environment_file),
pages_dir=Path(pages_dir),
artifacts=streamlit.artifacts,
stage_name=streamlit.stage,
main_file=Path(streamlit.main_file),
main_file=streamlit.main_file,
replace=replace,
query_warehouse=streamlit.query_warehouse,
additional_source_files=streamlit.additional_source_files,
title=streamlit.title,
**options,
)

if open_:
Expand All @@ -168,6 +178,53 @@ def streamlit_deploy(
return MessageResult(f"Streamlit successfully deployed and available under {url}")


def _migrate_v1_streamlit_to_v2(pd: ProjectDefinition):
    """Build a V2 project definition from a V1 ``streamlit`` section.

    The V1 fields ``main_file``, ``env_file``, ``pages_dir`` and
    ``additional_source_files`` are folded into a single ``artifacts`` list
    on a V2 entity registered under the key ``streamlit_app``.

    When ``env_file`` / ``pages_dir`` are not set, the defaults
    ``environment.yml`` / ``pages`` are used, but only if they exist relative
    to the current working directory.

    Raises:
        ClickException: if an explicitly configured env file or pages
            directory does not exist.
    """
    default_env_file = "environment.yml"
    default_pages_dir = "pages"
    streamlit = pd.streamlit

    # Process env file
    environment_file = streamlit.env_file
    if environment_file and not Path(environment_file).exists():
        raise ClickException(f"Provided file {environment_file} does not exist")
    elif environment_file is None and Path(default_env_file).exists():
        environment_file = default_env_file

    # Process pages dir
    pages_dir = streamlit.pages_dir
    if pages_dir and not Path(pages_dir).exists():
        raise ClickException(f"Provided file {pages_dir} does not exist")
    elif pages_dir is None and Path(default_pages_dir).exists():
        pages_dir = default_pages_dir

    # Build V2 definition; entries that could not be resolved are left out.
    artifacts = [
        artifact
        for artifact in (streamlit.main_file, environment_file, pages_dir)
        if artifact is not None
    ]
    if streamlit.additional_source_files:
        artifacts.extend(streamlit.additional_source_files)

    data = {
        "definition_version": "2",
        "entities": {
            "streamlit_app": {
                "type": "streamlit",
                "name": streamlit.name,
                "schema": streamlit.schema_name,
                "database": streamlit.database,
                "title": streamlit.title,
                "query_warehouse": streamlit.query_warehouse,
                "main_file": str(streamlit.main_file),
                # Pass the resolved directory and keep None as None:
                # str(None) would store the literal string "None".
                "pages_dir": str(pages_dir) if pages_dir is not None else None,
                "stage": streamlit.stage,
                "artifacts": artifacts,
            }
        },
    }
    return ProjectDefinitionV2(**data)


@app.command("get-url", requires_connection=True)
def get_url(
name: FQN = StreamlitNameArgument,
Expand Down
57 changes: 18 additions & 39 deletions src/snowflake/cli/plugins/streamlit/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import List, Optional

Expand Down Expand Up @@ -46,33 +45,25 @@ def share(self, streamlit_name: FQN, to_role: str) -> SnowflakeCursor:
def _put_streamlit_files(
self,
root_location: str,
main_file: Path,
environment_file: Optional[Path],
pages_dir: Optional[Path],
additional_source_files: Optional[List[Path]],
artifacts: Optional[List[Path]] = None,
):
if not artifacts:
return
stage_manager = StageManager()

stage_manager.put(main_file, root_location, 4, True)

if environment_file and environment_file.exists():
stage_manager.put(environment_file, root_location, 4, True)

if pages_dir and pages_dir.exists():
stage_manager.put(pages_dir / "*.py", f"{root_location}/pages", 4, True)

if additional_source_files:
for file in additional_source_files:
if os.sep in str(file):
destination = f"{root_location}/{str(file.parent)}"
else:
destination = root_location
stage_manager.put(file, destination, 4, True)
for file in artifacts:
if file.is_dir():
stage_manager.put(
f"{file.joinpath('*')}", f"{root_location}/{file}", 4, True
)
elif len(file.parts) > 1:
stage_manager.put(file, f"{root_location}/{file.parent}", 4, True)
else:
stage_manager.put(file, root_location, 4, True)

def _create_streamlit(
self,
streamlit_id: FQN,
main_file: Path,
main_file: str,
replace: Optional[bool] = None,
experimental: Optional[bool] = None,
query_warehouse: Optional[str] = None,
Expand All @@ -96,7 +87,7 @@ def _create_streamlit(
if from_stage_name:
query.append(f"ROOT_LOCATION = '{from_stage_name}'")

query.append(f"MAIN_FILE = '{main_file.name}'")
query.append(f"MAIN_FILE = '{main_file}'")

if query_warehouse:
query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
Expand All @@ -108,15 +99,12 @@ def _create_streamlit(
def deploy(
self,
streamlit_id: FQN,
main_file: Path,
environment_file: Optional[Path] = None,
pages_dir: Optional[Path] = None,
main_file: str,
artifacts: Optional[List[Path]] = None,
stage_name: Optional[str] = None,
query_warehouse: Optional[str] = None,
replace: Optional[bool] = False,
additional_source_files: Optional[List[Path]] = None,
title: Optional[str] = None,
**options,
):
# for backwards compatibility - quoted stage path might be case-sensitive
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers
Expand Down Expand Up @@ -169,10 +157,7 @@ def deploy(

self._put_streamlit_files(
root_location,
main_file,
environment_file,
pages_dir,
additional_source_files,
artifacts,
)
else:
"""
Expand All @@ -191,13 +176,7 @@ def deploy(
f"{stage_name}/{streamlit_name_for_root_location}"
)

self._put_streamlit_files(
root_location,
main_file,
environment_file,
pages_dir,
additional_source_files,
)
self._put_streamlit_files(root_location, artifacts)

self._create_streamlit(
streamlit_id,
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions tests/app/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def test_executing_command_sends_telemetry_result_data(

@mock.patch("snowflake.connector.connect")
@mock.patch("snowflake.cli.plugins.streamlit.commands.StreamlitManager")
@mock.patch.dict(
os.environ, {"SNOWFLAKE_CLI_FEATURES_ENABLE_PROJECT_DEFINITION_V2": "true"}
)
def test_executing_command_sends_project_definition_in_telemetry_data(
_, mock_conn, project_directory, runner
):
Expand Down
31 changes: 31 additions & 0 deletions tests/streamlit/__snapshots__/test_commands.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# serializer version: 1
# name: test_artifacts_must_exists
'''
+- Error ----------------------------------------------------------------------+
| During evaluation of DefinitionV20 in project definition following errors |
| were encountered: |
| For field entities.my_streamlit.streamlit you provided '{'artifacts': |
| ['streamlit_app.py', 'foo_bar.py', 'pages/', 'environment.yml'], |
| 'main_file': 'streamlit_app.py', 'name': 'test_streamlit_deploy_snowcli', |
| 'query_warehouse': 'xsmall', 'stage': 'streamlit', 'title': 'My Fancy |
| Streamlit', 'type': 'streamlit'}'. This caused: Value error, Specified |
| artifact foo_bar.py does not exist locally. |
+------------------------------------------------------------------------------+

'''
# ---
# name: test_main_file_must_be_in_artifacts
'''
+- Error ----------------------------------------------------------------------+
| During evaluation of DefinitionV20 in project definition following errors |
| were encountered: |
| For field entities.my_streamlit.streamlit you provided '{'artifacts': |
| ['streamlit_app.py', 'utils/utils.py', 'pages/', 'environment.yml'], |
| 'main_file': 'foo_bar.py', 'name': 'test_streamlit_deploy_snowcli', |
| 'query_warehouse': 'xsmall', 'stage': 'streamlit', 'title': 'My Fancy |
| Streamlit', 'type': 'streamlit'}'. This caused: Value error, Specified main |
| file foo_bar.py is not included in artifacts. |
+------------------------------------------------------------------------------+

'''
# ---
Loading
Loading