diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
index 7b95b6d0f3492..1489f23a13a06 100644
--- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
+++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
@@ -47,9 +47,10 @@ jest.mock(
{column.name}
),
);
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
+const getTableMetadataEndpoint =
+ /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
const getExtraTableMetadataEndpoint =
- 'glob:**/api/v1/database/*/table_metadata/extra/*';
+ /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';
beforeEach(() => {
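
The mocked endpoints above switch from fetch-mock glob strings to RegExp matchers so that requests carrying the table name and schema as query parameters still hit the mock. A minimal sketch of what the new pattern does and does not accept, written with Python's standard re module and hypothetical sample URLs:

    import re

    # Same pattern as the new test matcher, expressed as a Python raw string.
    TABLE_METADATA = re.compile(r"/api/v1/database/\d+/table_metadata/(?:\?.*)?$")

    # The query-string form used by the reworked endpoint matches ...
    assert TABLE_METADATA.search(
        "http://localhost/api/v1/database/1/table_metadata/?name=birth_names&schema=public"
    )
    # ... as does the bare form with no query string.
    assert TABLE_METADATA.search("http://localhost/api/v1/database/1/table_metadata/")
    # The old path-segment form (/table_metadata/<table>) no longer matches.
    assert not TABLE_METADATA.search(
        "http://localhost/api/v1/database/1/table_metadata/birth_names"
    )
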
diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
index c964fc32faaf0..b3f8aec8f99a0 100644
--- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
+++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
@@ -75,7 +75,7 @@ const DatasetPanelWrapper = ({
setLoading(true);
setHasColumns?.(false);
const path = schema
- ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}/`
+ ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
: `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
try {
const response = await SupersetClient.get({
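
The one-character fix above removes a trailing slash that was appended after the schema query parameter, which the table_metadata endpoint would otherwise receive as part of the schema value. A quick illustration with the Python standard library and hypothetical values:

    from urllib.parse import parse_qs, urlparse

    broken = "/api/v1/database/1/table_metadata/?name=my_table&schema=public/"
    fixed = "/api/v1/database/1/table_metadata/?name=my_table&schema=public"

    # With the stray slash, the backend sees a schema literally named "public/".
    assert parse_qs(urlparse(broken).query)["schema"] == ["public/"]
    # Without it, the intended schema name comes through.
    assert parse_qs(urlparse(fixed).query)["schema"] == ["public"]
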
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index e48f0b9bd8eca..719d5af588852 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -329,7 +329,7 @@ def short_data(self) -> dict[str, Any]:
"edit_url": self.url,
"id": self.id,
"uid": self.uid,
- "schema": self.schema,
+ "schema": self.schema or None,
"name": self.name,
"type": self.type,
"connection": self.connection,
@@ -383,7 +383,7 @@ def data(self) -> dict[str, Any]:
"datasource_name": self.datasource_name,
"table_name": self.datasource_name,
"type": self.type,
- "schema": self.schema,
+ "schema": self.schema or None,
"offset": self.offset,
"cache_timeout": self.cache_timeout,
"params": self.params,
@@ -1263,7 +1263,7 @@ def link(self) -> Markup:
def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
- return security_manager.get_schema_perm(self.database, self.schema)
+ return security_manager.get_schema_perm(self.database, self.schema or None)
def get_perm(self) -> str:
"""
@@ -1320,7 +1320,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database,
- table=Table(self.table_name, self.schema, self.catalog),
+ table=Table(self.table_name, self.schema or None, self.catalog),
normalize_columns=self.normalize_columns,
)
@@ -1336,7 +1336,7 @@ def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
- Table(self.table_name, self.schema, self.catalog),
+ Table(self.table_name, self.schema or None, self.catalog),
show_cols=False,
latest_partition=False,
)
@@ -1528,7 +1528,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
col_desc = get_columns_description(
self.database,
self.catalog,
- self.schema,
+ self.schema or None,
sql,
)
if not col_desc:
@@ -1735,7 +1735,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
return df
try:
- df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+ df = self.database.get_df(
+ sql, self.schema or None, mutator=assign_column_label
+ )
except (SupersetErrorException, SupersetErrorsException) as ex:
# SupersetError(s) exception should not be captured; instead, they should
# bubble up to the Flask error handler so they are returned as proper SIP-40
@@ -1772,7 +1774,7 @@ def get_sqla_table_object(self) -> Table:
return self.database.get_table(
Table(
self.table_name,
- self.schema,
+ self.schema or None,
self.catalog,
)
)
@@ -1790,7 +1792,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
for metric in self.database.get_metrics(
Table(
self.table_name,
- self.schema,
+ self.schema or None,
self.catalog,
)
)
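
The recurring `self.schema or None` in these hunks normalizes an empty-string schema, which the column can legally hold, to None before it reaches permission checks, Table identifiers, and metadata lookups. A minimal sketch of the idea, using a hypothetical dataclass as a stand-in for the SQLAlchemy model:

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass
    class FakeDataset:
        table_name: str
        schema: str | None = ""  # legacy rows may store "" instead of NULL

        @property
        def normalized_schema(self) -> str | None:
            # "" is falsy, so both "" and None collapse to None here.
            return self.schema or None

    assert FakeDataset("birth_names", "").normalized_schema is None
    assert FakeDataset("birth_names", None).normalized_schema is None
    assert FakeDataset("birth_names", "public").normalized_schema == "public"
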
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index f458b62165143..3cc1315129571 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -639,11 +639,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
return driver in cls.drivers
@classmethod
- def get_default_schema(cls, database: Database) -> str | None:
+ def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
"""
Return the default schema in a given database.
"""
- with database.get_inspector() as inspector:
+ with database.get_inspector(catalog=catalog) as inspector:
return inspector.default_schema_name
@classmethod
@@ -698,7 +698,7 @@ def get_default_schema_for_query(
return schema
# return the default schema of the database
- return cls.get_default_schema(database)
+ return cls.get_default_schema(database, query.catalog)
@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
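
get_default_schema now threads the catalog through to Database.get_inspector, so the default schema is resolved inside the catalog the query will actually run against, and get_default_schema_for_query forwards query.catalog accordingly. A rough sketch of the new call shape, assuming a development install of Superset and a MagicMock standing in for a real Database:

    from unittest import mock

    from superset.db_engine_specs.base import BaseEngineSpec

    database = mock.MagicMock()
    database.get_inspector.return_value.__enter__.return_value.default_schema_name = "public"

    # The catalog argument is forwarded to Database.get_inspector().
    assert BaseEngineSpec.get_default_schema(database, catalog="examples") == "public"
    database.get_inspector.assert_called_once_with(catalog="examples")
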
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 3245cdca4b73b..08a38894e6645 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -40,12 +40,12 @@
)
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.models.core import Database
- from superset.sql_parse import Table
with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
@@ -96,8 +96,11 @@ def get_extra_table_metadata(
),
}
- if database.has_view_by_name(table.table, table.schema):
- with database.get_inspector() as inspector:
+ if database.has_view(Table(table.table, table.schema)):
+ with database.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
metadata["view"] = inspector.get_view_definition(
table.table,
table.schema,
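
The Trino spec now goes through the generic Database.has_view(Table(...)) helper and opens its inspector pinned to the table's catalog and schema, so view definitions are looked up in the namespace the table actually lives in. A hedged sketch of the view branch with everything mocked (the surrounding signature is assumed from the hunk above):

    from unittest import mock

    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.sql_parse import Table

    database = mock.MagicMock()
    database.get_indexes.return_value = []  # skip the partitions branch
    database.has_view.return_value = True
    inspector = database.get_inspector.return_value.__enter__.return_value
    inspector.get_view_definition.return_value = "SELECT 1"

    metadata = TrinoEngineSpec.get_extra_table_metadata(
        database,
        Table("some_view", "some_schema", "some_catalog"),
    )

    # The inspector is scoped to the table's catalog and schema before the
    # view definition is fetched.
    database.get_inspector.assert_called_once_with(
        catalog="some_catalog", schema="some_schema"
    )
    assert metadata["view"] == "SELECT 1"
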
diff --git a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py
index 5fa35fb963b2e..ec5733e151044 100644
--- a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py
+++ b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py
@@ -17,7 +17,7 @@
"""Add catalog column
Revision ID: 5f57af97bc3f
-Revises: 5ad7321c2169
+Revises: d60591c5515f
Create Date: 2024-04-11 15:41:34.663989
"""
@@ -27,7 +27,7 @@
# revision identifiers, used by Alembic.
revision = "5f57af97bc3f"
-down_revision = "5ad7321c2169"
+down_revision = "d60591c5515f"
def upgrade():
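
The revision is re-parented from 5ad7321c2169 to d60591c5515f, the usual Alembic rebase when other migrations land on master first: the history is defined purely by down_revision links, so the new column migration must point at the current head to keep the chain linear. A toy sketch of how that linkage resolves (the chain below is hypothetical and only for illustration):

    from __future__ import annotations

    # Each revision names its parent via down_revision; Alembic walks these
    # links to order upgrades. (Hypothetical chain, not the real history.)
    chain = {
        "5f57af97bc3f": "d60591c5515f",  # this migration after the rebase
        "d60591c5515f": "5ad7321c2169",  # head that landed in the meantime
        "5ad7321c2169": None,
    }

    def lineage(rev: str | None) -> list[str]:
        out: list[str] = []
        while rev is not None:
            out.append(rev)
            rev = chain[rev]
        return out

    assert lineage("5f57af97bc3f") == ["5f57af97bc3f", "d60591c5515f", "5ad7321c2169"]
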
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 359547118e9b0..c10d589d97fd9 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -774,13 +774,13 @@ def test_create_dataset_validate_tables_exists(self):
@patch("superset.models.core.Database.get_columns")
@patch("superset.models.core.Database.has_table")
- @patch("superset.models.core.Database.has_view_by_name")
+ @patch("superset.models.core.Database.has_view")
@patch("superset.models.core.Database.get_table")
def test_create_dataset_validate_view_exists(
self,
mock_get_table,
mock_has_table,
- mock_has_view_by_name,
+ mock_has_view,
mock_get_columns,
):
"""
@@ -797,7 +797,7 @@ def test_create_dataset_validate_view_exists(
]
mock_has_table.return_value = False
- mock_has_view_by_name.return_value = True
+ mock_has_view.return_value = True
mock_get_table.return_value = None
example_db = get_example_database()
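
The patch targets are renamed alongside the Database API: the dataset validation path now goes through Database.has_view, and mock.patch resolves its target eagerly, so patching the removed has_view_by_name attribute would fail with an AttributeError at test setup. A tiny sketch of the new target, assuming a development install of Superset:

    from unittest.mock import patch

    # Patching the renamed method works; patching the old name would raise
    # AttributeError once has_view_by_name no longer exists on Database.
    with patch("superset.models.core.Database.has_view", return_value=True) as has_view:
        assert has_view.return_value is True
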
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index d7498dc4fee84..c8db1f912ad21 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -30,7 +30,7 @@
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
@@ -238,7 +238,7 @@ def test_get_table_names(self):
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
example_db = get_example_database()
- sqla_table = example_db.get_table("energy_usage")
+ sqla_table = example_db.get_table(Table("energy_usage"))
dialect = example_db.get_dialect()
# TODO: fix column type conversion for presto.
@@ -540,8 +540,7 @@ def test_get_indexes():
BaseEngineSpec.get_indexes(
database=mock.Mock(),
inspector=inspector,
- table_name="bar",
- schema="foo",
+ table=Table("bar", "foo"),
)
== indexes
)
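
This and the following test files replace loose table_name/schema argument pairs with the Table identifier from superset.sql_parse, which also carries the new catalog field. A small sketch of its semantics, assuming a development install of Superset:

    from superset.sql_parse import Table

    # Table is a lightweight named tuple: (table, schema=None, catalog=None).
    assert Table("energy_usage").schema is None
    assert Table("bar", "foo") == Table(table="bar", schema="foo")
    assert Table("t", "s", "c").catalog == "c"
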
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index ce184685db540..53f9137076bb8 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -165,8 +165,7 @@ def test_get_indexes(self):
BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
)
== []
)
@@ -184,8 +183,7 @@ def test_get_indexes(self):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
) == [
{
"name": "partition",
@@ -207,8 +205,7 @@ def test_get_indexes(self):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
) == [
{
"name": "partition",
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index 39d2c30fd1162..4d1a84508167b 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -23,7 +23,7 @@
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
-from superset.sql_parse import Table, ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from tests.integration_tests.test_app import app
@@ -328,7 +328,10 @@ def test_where_latest_partition(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", database, select(), columns
+ database,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -341,7 +344,10 @@ def test_where_latest_partition_super_method_exception(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", database, select(), columns
+ database,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
assert result is None
mock_method.assert_called()
@@ -353,7 +359,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
db = mock.Mock()
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select()
+ db,
+ Table("test_table", "test_schema"),
+ select(),
)
assert result is None
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 3f7bc52a57d2a..607afa6953fcd 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -82,7 +82,7 @@ def verify_presto_column(self, column, expected_results):
row = mock.Mock()
row.Column, row.Type, row.Null = column
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
- results = PrestoEngineSpec.get_columns(inspector, "", "")
+ results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
self.assertEqual(len(expected_results), len(results))
for expected_result, result in zip(expected_results, results):
self.assertEqual(expected_result[0], result["column_name"])
@@ -573,7 +573,10 @@ def test_presto_where_latest_partition(self):
db.get_df = mock.Mock(return_value=df)
columns = [{"name": "ds"}, {"name": "hour"}]
result = PrestoEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select(), columns
+ db,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
@@ -802,7 +805,7 @@ def test_show_columns(self):
return_value=["a", "b"]
)
table_name = "table_name"
- result = PrestoEngineSpec._show_columns(inspector, table_name, None)
+ result = PrestoEngineSpec._show_columns(inspector, Table(table_name))
assert result == ["a", "b"]
inspector.bind.execute.assert_called_once_with(
f'SHOW COLUMNS FROM "{table_name}"'
@@ -818,7 +821,7 @@ def test_show_columns_with_schema(self):
)
table_name = "table_name"
schema = "schema"
- result = PrestoEngineSpec._show_columns(inspector, table_name, schema)
+ result = PrestoEngineSpec._show_columns(inspector, Table(table_name, schema))
assert result == ["a", "b"]
inspector.bind.execute.assert_called_once_with(
f'SHOW COLUMNS FROM "{schema}"."{table_name}"'
@@ -848,7 +851,14 @@ def test_select_star_no_presto_expand_data(self, mock_select_star):
]
PrestoEngineSpec.select_star(database, Table(table_name), engine, cols=cols)
mock_select_star.assert_called_once_with(
- database, table_name, engine, None, 100, False, True, True, cols
+ database,
+ Table(table_name),
+ engine,
+ 100,
+ False,
+ True,
+ True,
+ cols,
)
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
@@ -877,9 +887,8 @@ def test_select_star_presto_expand_data(
)
mock_select_star.assert_called_once_with(
database,
- table_name,
+ Table(table_name),
engine,
- None,
100,
True,
True,
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index d7aeaf1c5f036..5bd83828ed2c6 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -311,15 +311,15 @@ def test_convert_dttm(
assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
-def test_get_extra_table_metadata() -> None:
+def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
- db_mock = Mock()
+ db_mock = mocker.MagicMock()
db_mock.get_indexes = Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db_mock.get_extra = Mock(return_value={})
- db_mock.has_view_by_name = Mock(return_value=None)
+ db_mock.has_view = Mock(return_value=None)
db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
result = TrinoEngineSpec.get_extra_table_metadata(
db_mock,