Support Arbitrary Catalog IDs on Athena Data Source (#7059)
Co-authored-by: SeongTae Jeong <seongtaejg@gmail.com>
dtaniwaki and lucydodo authored Jul 24, 2024
1 parent 80f7ba1 commit c244e75
Showing 2 changed files with 111 additions and 5 deletions.
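
For context, a data source that reads schema from more than one Glue catalog would be configured along these lines after this change (a sketch; the region, bucket, and account IDs are placeholders, while the option names come from the schema added in athena.py below):

# Hypothetical Athena data source options using the new field (placeholder values):
options = {
    "region": "us-east-1",
    "s3_staging_dir": "s3://example-athena-results/",
    "glue": True,                                  # pull schema from the Glue Data Catalog
    "catalog_ids": "123456789012, 210987654321",   # comma-separated; leave blank for the default catalog
}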
22 changes: 17 additions & 5 deletions redash/query_runner/athena.py
@@ -76,6 +76,10 @@ def configuration_schema(cls):
"default": "default",
},
"glue": {"type": "boolean", "title": "Use Glue Data Catalog"},
"catalog_ids": {
"type": "string",
"title": "Enter Glue Data Catalog IDs, separated by commas (leave blank for default catalog)",
},
"work_group": {
"type": "string",
"title": "Athena Work Group",
@@ -88,7 +92,7 @@
},
},
"required": ["region", "s3_staging_dir"],
"extra_options": ["glue", "cost_per_tb"],
"extra_options": ["glue", "catalog_ids", "cost_per_tb"],
"order": [
"region",
"s3_staging_dir",
@@ -172,16 +176,23 @@ def _get_iam_credentials(self, user=None):
"region_name": self.configuration["region"],
}

def __get_schema_from_glue(self):
def __get_schema_from_glue(self, catalog_id=""):
client = boto3.client("glue", **self._get_iam_credentials())
schema = {}

database_paginator = client.get_paginator("get_databases")
table_paginator = client.get_paginator("get_tables")

for databases in database_paginator.paginate():
databases_iterator = database_paginator.paginate(
**({"CatalogId": catalog_id} if catalog_id != "" else {}),
)

for databases in databases_iterator:
for database in databases["DatabaseList"]:
iterator = table_paginator.paginate(DatabaseName=database["Name"])
iterator = table_paginator.paginate(
DatabaseName=database["Name"],
**({"CatalogId": catalog_id} if catalog_id != "" else {}),
)
for table in iterator.search("TableList[]"):
table_name = "%s.%s" % (database["Name"], table["Name"])
if "StorageDescriptor" not in table:
@@ -196,7 +207,8 @@ def __get_schema_from_glue(self):

def get_schema(self, get_stats=False):
if self.configuration.get("glue", False):
return self.__get_schema_from_glue()
catalog_ids = [id.strip() for id in self.configuration.get("catalog_ids", "").split(",")]
return sum([self.__get_schema_from_glue(catalog_id) for catalog_id in catalog_ids], [])

schema = {}
query = """
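
The lines above are the heart of the change: the comma-separated catalog_ids string is split into individual IDs, each ID is forwarded to the Glue paginators as CatalogId (omitted when blank, so the default catalog is used), and get_schema concatenates the per-catalog schema lists. A standalone sketch of that handling, separate from the commit itself:

# Illustrative only -- mirrors the parsing and flattening added above.
def parse_catalog_ids(raw):
    # "".split(",") yields [""], so a blank setting still produces one
    # (default-catalog) lookup rather than none.
    return [catalog_id.strip() for catalog_id in raw.split(",")]

def paginator_kwargs(catalog_id):
    # CatalogId is only sent to Glue when an explicit ID was configured.
    return {"CatalogId": catalog_id} if catalog_id != "" else {}

print(parse_catalog_ids("123456789012, 210987654321"))  # ['123456789012', '210987654321']
print(parse_catalog_ids(""))                            # [''] -> default catalog
print(paginator_kwargs(""))                             # {} -> no CatalogId argument
print(sum([[{"name": "db1.t1"}], [{"name": "db2.t2"}]], []))  # flattened schema list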
94 changes: 94 additions & 0 deletions tests/query_runner/test_athena.py
@@ -221,3 +221,97 @@ def test_no_storage_descriptor_table(self):
)
with self.stubber:
assert query_runner.get_schema() == []

def test_multi_catalog_tables(self):
"""Tables of multi-catalogs"""
query_runner = Athena({"glue": True, "region": "mars-east-1", "catalog_ids": "foo,bar"})

self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test1"}]}, {"CatalogId": "foo"})
self.stubber.add_response(
"get_tables",
{
"TableList": [
{
"Name": "jdbc_table",
"StorageDescriptor": {
"Columns": [{"Name": "row_id", "Type": "int"}],
"Location": "Database.Schema.Table",
"Compressed": False,
"NumberOfBuckets": -1,
"SerdeInfo": {"Parameters": {}},
"BucketColumns": [],
"SortColumns": [],
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
"StoredAsSubDirectories": False,
},
"PartitionKeys": [],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
}
]
},
{"CatalogId": "foo", "DatabaseName": "test1"},
)
self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test2"}]}, {"CatalogId": "bar"})
self.stubber.add_response(
"get_tables",
{
"TableList": [
{
"Name": "jdbc_table",
"StorageDescriptor": {
"Columns": [{"Name": "row_id", "Type": "int"}],
"Location": "Database.Schema.Table",
"Compressed": False,
"NumberOfBuckets": -1,
"SerdeInfo": {"Parameters": {}},
"BucketColumns": [],
"SortColumns": [],
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
"StoredAsSubDirectories": False,
},
"PartitionKeys": [],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
}
]
},
{"CatalogId": "bar", "DatabaseName": "test2"},
)
with self.stubber:
assert query_runner.get_schema() == [
{"columns": ["row_id"], "name": "test1.jdbc_table"},
{"columns": ["row_id"], "name": "test2.jdbc_table"},
]
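
The test leans on botocore's Stubber queuing responses in the order they are added and validating the expected parameters of each call, so the get_databases/get_tables responses above must be registered in exactly the order the runner issues them (catalog "foo" first, then "bar"). A minimal standalone illustration of that pattern, with placeholder region and credentials:

# Sketch only -- not part of the commit; dummy credentials keep request signing happy.
import boto3
from botocore.stub import Stubber

client = boto3.client(
    "glue",
    region_name="us-east-1",
    aws_access_key_id="dummy",
    aws_secret_access_key="dummy",
)
stubber = Stubber(client)
# Each queued response is matched, in order, against one API call and its params.
stubber.add_response("get_databases", {"DatabaseList": [{"Name": "db1"}]}, {"CatalogId": "foo"})
with stubber:
    print(client.get_databases(CatalogId="foo")["DatabaseList"])  # [{'Name': 'db1'}]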
