diff --git a/redash/query_runner/athena.py b/redash/query_runner/athena.py
index acde734f02..0d70a2a0e3 100644
--- a/redash/query_runner/athena.py
+++ b/redash/query_runner/athena.py
@@ -76,6 +76,10 @@ def configuration_schema(cls):
                     "default": "default",
                 },
                 "glue": {"type": "boolean", "title": "Use Glue Data Catalog"},
+                "catalog_ids": {
+                    "type": "string",
+                    "title": "Enter Glue Data Catalog IDs, separated by commas (leave blank for default catalog)",
+                },
                 "work_group": {
                     "type": "string",
                     "title": "Athena Work Group",
@@ -88,7 +92,7 @@ def configuration_schema(cls):
                 },
             },
             "required": ["region", "s3_staging_dir"],
-            "extra_options": ["glue", "cost_per_tb"],
+            "extra_options": ["glue", "catalog_ids", "cost_per_tb"],
             "order": [
                 "region",
                 "s3_staging_dir",
@@ -172,16 +176,23 @@ def _get_iam_credentials(self, user=None):
                 "region_name": self.configuration["region"],
             }
 
-    def __get_schema_from_glue(self):
+    def __get_schema_from_glue(self, catalog_id=""):
         client = boto3.client("glue", **self._get_iam_credentials())
         schema = {}
 
         database_paginator = client.get_paginator("get_databases")
         table_paginator = client.get_paginator("get_tables")
 
-        for databases in database_paginator.paginate():
+        databases_iterator = database_paginator.paginate(
+            **({"CatalogId": catalog_id} if catalog_id != "" else {}),
+        )
+
+        for databases in databases_iterator:
             for database in databases["DatabaseList"]:
-                iterator = table_paginator.paginate(DatabaseName=database["Name"])
+                iterator = table_paginator.paginate(
+                    DatabaseName=database["Name"],
+                    **({"CatalogId": catalog_id} if catalog_id != "" else {}),
+                )
                 for table in iterator.search("TableList[]"):
                     table_name = "%s.%s" % (database["Name"], table["Name"])
                     if "StorageDescriptor" not in table:
@@ -196,7 +207,10 @@ def __get_schema_from_glue(self):
 
     def get_schema(self, get_stats=False):
         if self.configuration.get("glue", False):
-            return self.__get_schema_from_glue()
+            # Skip blank entries (e.g. a trailing comma); a lone "" selects the default catalog.
+            raw_ids = (self.configuration.get("catalog_ids") or "").split(",")
+            catalog_ids = [catalog_id.strip() for catalog_id in raw_ids if catalog_id.strip()] or [""]
+            return [table for catalog_id in catalog_ids for table in self.__get_schema_from_glue(catalog_id)]
 
         schema = {}
         query = """
diff --git a/tests/query_runner/test_athena.py b/tests/query_runner/test_athena.py
index 2ac4ee42fe..6cda21c03b 100644
--- a/tests/query_runner/test_athena.py
+++ b/tests/query_runner/test_athena.py
@@ -221,3 +221,97 @@ def test_no_storage_descriptor_table(self):
         )
         with self.stubber:
             assert query_runner.get_schema() == []
+
+    def test_multi_catalog_tables(self):
+        """Tables of multi-catalogs"""
+        query_runner = Athena({"glue": True, "region": "mars-east-1", "catalog_ids": "foo,bar"})
+
+        self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test1"}]}, {"CatalogId": "foo"})
+        self.stubber.add_response(
+            "get_tables",
+            {
+                "TableList": [
+                    {
+                        "Name": "jdbc_table",
+                        "StorageDescriptor": {
+                            "Columns": [{"Name": "row_id", "Type": "int"}],
+                            "Location": "Database.Schema.Table",
+                            "Compressed": False,
+                            "NumberOfBuckets": -1,
+                            "SerdeInfo": {"Parameters": {}},
+                            "BucketColumns": [],
+                            "SortColumns": [],
+                            "Parameters": {
+                                "CrawlerSchemaDeserializerVersion": "1.0",
+                                "CrawlerSchemaSerializerVersion": "1.0",
+                                "UPDATED_BY_CRAWLER": "jdbc",
+                                "classification": "sqlserver",
+                                "compressionType": "none",
+                                "connectionName": "jdbctest",
+                                "typeOfData": "view",
+                            },
+                            "StoredAsSubDirectories": False,
+                        },
+                        "PartitionKeys": [],
+                        "TableType": "EXTERNAL_TABLE",
+                        "Parameters": {
+                            "CrawlerSchemaDeserializerVersion": "1.0",
+                            "CrawlerSchemaSerializerVersion": "1.0",
+                            "UPDATED_BY_CRAWLER": "jdbc",
+                            "classification": "sqlserver",
+                            "compressionType": "none",
+                            "connectionName": "jdbctest",
+                            "typeOfData": "view",
+                        },
+                    }
+                ]
+            },
+            {"CatalogId": "foo", "DatabaseName": "test1"},
+        )
+        self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test2"}]}, {"CatalogId": "bar"})
+        self.stubber.add_response(
+            "get_tables",
+            {
+                "TableList": [
+                    {
+                        "Name": "jdbc_table",
+                        "StorageDescriptor": {
+                            "Columns": [{"Name": "row_id", "Type": "int"}],
+                            "Location": "Database.Schema.Table",
+                            "Compressed": False,
+                            "NumberOfBuckets": -1,
+                            "SerdeInfo": {"Parameters": {}},
+                            "BucketColumns": [],
+                            "SortColumns": [],
+                            "Parameters": {
+                                "CrawlerSchemaDeserializerVersion": "1.0",
+                                "CrawlerSchemaSerializerVersion": "1.0",
+                                "UPDATED_BY_CRAWLER": "jdbc",
+                                "classification": "sqlserver",
+                                "compressionType": "none",
+                                "connectionName": "jdbctest",
+                                "typeOfData": "view",
+                            },
+                            "StoredAsSubDirectories": False,
+                        },
+                        "PartitionKeys": [],
+                        "TableType": "EXTERNAL_TABLE",
+                        "Parameters": {
+                            "CrawlerSchemaDeserializerVersion": "1.0",
+                            "CrawlerSchemaSerializerVersion": "1.0",
+                            "UPDATED_BY_CRAWLER": "jdbc",
+                            "classification": "sqlserver",
+                            "compressionType": "none",
+                            "connectionName": "jdbctest",
+                            "typeOfData": "view",
+                        },
+                    }
+                ]
+            },
+            {"CatalogId": "bar", "DatabaseName": "test2"},
+        )
+        with self.stubber:
+            assert query_runner.get_schema() == [
+                {"columns": ["row_id"], "name": "test1.jdbc_table"},
+                {"columns": ["row_id"], "name": "test2.jdbc_table"},
+            ]