From d1642aaba694b5619a0aab7706da7f3ddc6d9ffa Mon Sep 17 00:00:00 2001 From: Florian Date: Thu, 30 Dec 2021 14:16:30 +0100 Subject: [PATCH 1/2] Add KafkaSchemaRegistrySource in the external providers in the Python binding --- python/metadata_guardian/report.py | 14 ++- python/metadata_guardian/scanner.py | 16 ++- python/metadata_guardian/source/__init__.py | 1 + .../source/external/aws_source.py | 2 +- .../external/kafka_schema_registry_source.py | 104 ++++++++++++++++++ python/pyproject.toml | 3 +- .../external/test_kafka_schema_registry.py | 69 ++++++++++++ .../tests/external/test_snowflake_source.py | 31 ++++++ 8 files changed, 227 insertions(+), 13 deletions(-) create mode 100644 python/metadata_guardian/source/external/kafka_schema_registry_source.py create mode 100644 python/tests/external/test_kafka_schema_registry.py diff --git a/python/metadata_guardian/report.py b/python/metadata_guardian/report.py index d366378..fca97f2 100644 --- a/python/metadata_guardian/report.py +++ b/python/metadata_guardian/report.py @@ -2,6 +2,7 @@ from typing import List, NamedTuple, Optional, Tuple from rich.console import Console +from rich.markup import escape from rich.progress import ( BarColumn, Progress, @@ -23,13 +24,14 @@ class ProgressionBar(Progress): task_id: Optional[TaskID] = None - def __init__(self) -> None: + def __init__(self, disable: bool) -> None: super().__init__( SpinnerColumn(), "[progress.description]{task.description}: [red]{task.fields[current_item]}", BarColumn(), "[progress.percentage]{task.percentage:>3.0f}% ({task.completed}/{task.total})-", TimeRemainingColumn(), + disable=disable, ) def __enter__(self) -> "ProgressionBar": @@ -43,10 +45,10 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore def add_task_with_item( self, - item_name: Optional[str], + item_name: str, source_type: str, total: int, - current_item: str = "", + current_item: str = "Starting", ) -> None: """ Add task in the Progression Bar. @@ -56,10 +58,12 @@ def add_task_with_item( :param total: total of the number of tables :return: the created Task """ + task_details = f"[{item_name}]" if item_name else "" + task_description = f"[bold cyan]Searching in the {escape(source_type)} metadata source{escape(task_details)}" task_id = super().add_task( - f"[bold cyan]Searching in {item_name} for the {source_type} metadata source", + description=task_description, total=total, - current_item=current_item, + current_item=escape(current_item), ) self.task_id = task_id diff --git a/python/metadata_guardian/scanner.py b/python/metadata_guardian/scanner.py index a9de3a6..ea7af86 100644 --- a/python/metadata_guardian/scanner.py +++ b/python/metadata_guardian/scanner.py @@ -13,7 +13,9 @@ class Scanner(ABC): - """Scanner interface.""" + """ + Scanner Interface. 
+ """ @abstractmethod def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: @@ -68,6 +70,7 @@ class ColumnScanner(Scanner): """Column Scanner instance.""" data_rules: DataRules + progression_bar_disable: bool = False def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: """ @@ -78,7 +81,7 @@ def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: logger.debug( f"[blue]Launch the metadata scanning of the local provider {source.type}" ) - with ProgressionBar() as progression_bar: + with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: report = MetadataGuardianReport( report_results=[ ReportResults( @@ -108,9 +111,9 @@ def scan_external( :return: a Metadata Guardian report """ logger.debug( - f"[blue]Launch the metadata scanning of the external provider {source.type} for the database {database_name}" + f"[blue]Launch the metadata scanning of the external provider {source.type} for {database_name}" ) - with ProgressionBar() as progression_bar: + with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: if table_name: progression_bar.add_task_with_item( item_name=database_name, @@ -206,7 +209,7 @@ async def async_validate_words( results=self.data_rules.validate_words(words=words), ) - with ProgressionBar() as progression_bar: + with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: if table_name: tasks = [ async_validate_words( @@ -239,6 +242,7 @@ class ContentFilesScanner: """Content Files Scanner instance.""" data_rules: DataRules + progression_bar_disable: bool = False def scan_local_file(self, path: str) -> MetadataGuardianReport: """ @@ -250,7 +254,7 @@ def scan_local_file(self, path: str) -> MetadataGuardianReport: f"[blue]Launch the metadata scanning the content of the file {path}" ) progression_bar: ProgressionBar - with ProgressionBar() as progression_bar: + with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: progression_bar.add_task_with_item( item_name=path, source_type="files", total=1 ) diff --git a/python/metadata_guardian/source/__init__.py b/python/metadata_guardian/source/__init__.py index 8435ba5..a444342 100644 --- a/python/metadata_guardian/source/__init__.py +++ b/python/metadata_guardian/source/__init__.py @@ -2,6 +2,7 @@ from .external.deltatable_source import * from .external.external_metadata_source import * from .external.gcp_source import * +from .external.kafka_schema_registry_source import * from .external.snowflake_source import * from .local.avro_schema_source import * from .local.avro_source import * diff --git a/python/metadata_guardian/source/external/aws_source.py b/python/metadata_guardian/source/external/aws_source.py index 83e97dc..0ae97f1 100644 --- a/python/metadata_guardian/source/external/aws_source.py +++ b/python/metadata_guardian/source/external/aws_source.py @@ -48,7 +48,7 @@ def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: """ - Get column names from the table. + Get the column names from the table. 
:param database_name: the database name :param table_name: the table name :param include_comment: include the comment diff --git a/python/metadata_guardian/source/external/kafka_schema_registry_source.py b/python/metadata_guardian/source/external/kafka_schema_registry_source.py new file mode 100644 index 0000000..feb3d1d --- /dev/null +++ b/python/metadata_guardian/source/external/kafka_schema_registry_source.py @@ -0,0 +1,104 @@ +import json +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional + +from loguru import logger + +from .external_metadata_source import ( + ExternalMetadataSource, + ExternalMetadataSourceException, +) + +try: + from confluent_kafka.schema_registry import SchemaRegistryClient + + KAFKA_SCHEMA_REGISTRY_INSTALLED = True +except ImportError: + logger.debug("Kafka Schema Registry optional dependency is not installed.") + KAFKA_SCHEMA_REGISTRY_INSTALLED = False + +if KAFKA_SCHEMA_REGISTRY_INSTALLED: + + class KafkaSchemaRegistryAuthentication(Enum): + """Authentication method for Kafka Schema Registry source.""" + + USER_PWD = 1 + + @dataclass + class KafkaSchemaRegistrySource(ExternalMetadataSource): + """Instance of a Kafka Schema Registry source.""" + + url: str + ssl_certificate_location: Optional[str] = None + ssl_key_location: Optional[str] = None + connection: Optional[Any] = None + authenticator: Optional[ + KafkaSchemaRegistryAuthentication + ] = KafkaSchemaRegistryAuthentication.USER_PWD + comment_field_name: str = "doc" + + def get_connection(self) -> None: + """ + Get the connection of the Kafka Schema Registry. + :return: + """ + if self.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD: + self.connection = SchemaRegistryClient( + { + "url": self.url, + } + ) + else: + raise NotImplementedError() + + def get_column_names( + self, database_name: str, table_name: str, include_comment: bool = False + ) -> List[str]: + """ + Get the column names from the subject. + :param database_name: not relevant + :param table_name: the subject name + :param include_comment: include the comment + :return: the list of the column names + """ + try: + if not self.connection: + self.get_connection() + registered_schema = self.connection.get_latest_version(table_name) + columns = list() + for field in json.loads(registered_schema.schema.schema_str)["fields"]: + columns.append(field["name"].lower()) + if include_comment and self.comment_field_name in field: + columns.append(field[self.comment_field_name].lower()) + return columns + except Exception as exception: + logger.exception( + f"Error in getting columns name from the Kafka Schema Registry {table_name}" + ) + raise exception + + def get_table_names_list(self, database_name: str) -> List[str]: + """ + Get all the subjects from the Schema Registry. + :param database_name: not relevant in that case + :return: the list of the table names of the database + """ + try: + if not self.connection: + self.get_connection() + all_subjects = self.connection.get_subjects() + return all_subjects + except Exception as exception: + logger.exception( + f"Error all the subjects from the subject in the Kafka Schema Registry" + ) + raise ExternalMetadataSourceException(exception) + + @property + def type(self) -> str: + """ + The type of the source. + :return: the name of the source. 
+ """ + return "Kafka Schema Registry" diff --git a/python/pyproject.toml b/python/pyproject.toml index d58a177..d3d33de 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -23,12 +23,13 @@ dependencies = [ ] [project.optional-dependencies] -all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery"] +all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery", "confluent-kafka"] snowflake = [ "snowflake-connector-python" ] avro = [ "avro" ] aws = [ "boto3", "boto3-stubs[athena,glue]" ] gcp = [ "google-cloud-bigquery"] deltalake = [ "deltalake" ] +kafka_schema_registry = [ "confluent-kafka" ] devel = [ "mypy", "black", diff --git a/python/tests/external/test_kafka_schema_registry.py b/python/tests/external/test_kafka_schema_registry.py new file mode 100644 index 0000000..c29e066 --- /dev/null +++ b/python/tests/external/test_kafka_schema_registry.py @@ -0,0 +1,69 @@ +from unittest.mock import patch + +from confluent_kafka.schema_registry import RegisteredSchema, Schema + +from metadata_guardian.source import ( + KafkaSchemaRegistryAuthentication, + KafkaSchemaRegistrySource, +) + + +@patch("confluent_kafka.schema_registry.SchemaRegistryClient") +def test_kafka_schema_registry_source_get_column_names(mock_connection): + url = "url" + subject_name = "subject_name" + expected = ["key", "value", "doc"] + + source = KafkaSchemaRegistrySource( + url=url, + ) + schema_id = "schema_id" + schema_str = """{ + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "value", + "type": "string", + "doc": "doc" + } + ], + "name": "test_one", + "namespace": "test.one", + "type": "record" + }""" + schema = RegisteredSchema( + schema_id=schema_id, + schema=Schema(schema_str, "AVRO", []), + subject=subject_name, + version=1, + ) + mock_connection.get_latest_version.return_value = schema + source.connection = mock_connection + + column_names = source.get_column_names( + database_name=None, table_name=subject_name, include_comment=True + ) + + assert column_names == expected + assert source.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD + + +@patch("confluent_kafka.schema_registry.SchemaRegistryClient") +def test_kafka_schema_registry_source_get_table_names_list(mock_connection): + url = "url" + expected = ["subject1", "subject2"] + + source = KafkaSchemaRegistrySource( + url=url, + ) + subjects = ["subject1", "subject2"] + mock_connection.get_subjects.return_value = subjects + source.connection = mock_connection + + subjects_list = source.get_table_names_list(database_name=None) + + assert subjects_list == expected + assert source.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD diff --git a/python/tests/external/test_snowflake_source.py b/python/tests/external/test_snowflake_source.py index 22fee59..acdb0d4 100644 --- a/python/tests/external/test_snowflake_source.py +++ b/python/tests/external/test_snowflake_source.py @@ -35,3 +35,34 @@ def test_snowflake_source_get_column_names(mock_connection): assert column_names == expected assert source.authenticator == SnowflakeAuthenticator.USER_PWD + + +@patch("snowflake.connector") +def test_snowflake_source_get_table_names_list(mock_connection): + database_name = "test_database" + schema_name = "PUBLIC" + sf_account = "sf_account" + sf_user = "sf_user" + sf_password = "sf_password" + warehouse = "warehouse" + mocked_cursor_one = mock_connection.connect().cursor.return_value + 
mocked_cursor_one.description = [["name"], ["phone"]] + mocked_cursor_one.fetchall.return_value = [ + (database_name, "TEST_TABLE"), + (database_name, "TEST_TABLE2"), + ] + mocked_cursor_one.execute.call_args == f'SHOW TABLES IN DATABASE "{database_name}"' + expected = ["TEST_TABLE", "TEST_TABLE2"] + + source = SnowflakeSource( + sf_account=sf_account, + sf_user=sf_user, + sf_password=sf_password, + warehouse=warehouse, + schema_name=schema_name, + ) + + column_names = source.get_table_names_list(database_name=database_name) + + assert column_names == expected + assert source.authenticator == SnowflakeAuthenticator.USER_PWD From 5a7cc233ed51fab6a987934972eb623241678998 Mon Sep 17 00:00:00 2001 From: Florian Date: Fri, 31 Dec 2021 12:31:34 +0100 Subject: [PATCH 2/2] Improve the README and add examples --- README.adoc | 9 +- python/docs/source/installation.rst | 4 +- python/docs/source/usage.rst | 42 ++++++--- python/examples/README.md | 3 + ...scan_external_sources_custom_data_rules.py | 86 +++++++++++++++++ .../scan_external_sources_database.py | 90 ++++++++++++++++++ .../scan_external_sources_database_async.py | 93 +++++++++++++++++++ .../examples/scan_external_sources_table.py | 92 ++++++++++++++++++ python/examples/scan_local_sources.py | 59 ++++++++++++ python/metadata_guardian/report.py | 2 +- python/metadata_guardian/scanner.py | 5 + .../source/external/aws_source.py | 22 +++-- .../source/external/deltatable_source.py | 16 +++- .../external/external_metadata_source.py | 34 ++++++- .../source/external/gcp_source.py | 30 +++--- .../external/kafka_schema_registry_source.py | 17 +++- .../source/external/snowflake_source.py | 17 +++- .../tests/external/test_deltatable_source.py | 5 +- python/tests/external/test_gcp_source.py | 20 ++-- 19 files changed, 584 insertions(+), 62 deletions(-) create mode 100644 python/examples/README.md create mode 100644 python/examples/scan_external_sources_custom_data_rules.py create mode 100644 python/examples/scan_external_sources_database.py create mode 100644 python/examples/scan_external_sources_database_async.py create mode 100644 python/examples/scan_external_sources_table.py create mode 100644 python/examples/scan_local_sources.py diff --git a/README.adoc b/README.adoc index da32d23..23350b2 100644 --- a/README.adoc +++ b/README.adoc @@ -23,9 +23,10 @@ Using Rust, it makes blazing fast multi-regex matching. - Deltalake - GCP: BigQuery - Snowflake +- Kafka Schema Registry == Data Rules -The available data rules are: *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/pii_rules.yaml[PII]* and *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/inclusion_rules.yaml[INCLUSION]*. But it aims to be extended with custom data rules that could serve multiple purposes (for example: detect data that may contain IA biais, detect credentials...). +The available data rules are here: *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/pii_rules.yaml[PII]* and *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/inclusion_rules.yaml[INCLUSION]*. But it aims to be extended with custom data rules that could serve multiple purposes (for example: detect data that may contain IA biais, detect credentials...). 
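As a quick illustration of both points above — using a built-in category or pointing the engine at a custom rules file — here is a minimal sketch built from the Python examples added elsewhere in this change set; the registry URL, the YAML path, and the subject name are placeholders, and `database_name` is not relevant for the Kafka Schema Registry source:

```python
from metadata_guardian import AvailableCategory, ColumnScanner, DataRules
from metadata_guardian.source import KafkaSchemaRegistrySource

# Load a built-in data rules category (PII or INCLUSION)...
data_rules = DataRules.from_available_category(category=AvailableCategory.PII)
# ...or load a custom YAML specification instead (placeholder path):
# data_rules = DataRules(path="my_custom_rules.yaml")

# The external source added by this change set; the URL is a placeholder.
source = KafkaSchemaRegistrySource(url="http://localhost:8081")
column_scanner = ColumnScanner(data_rules=data_rules)

# database_name is ignored by this source; table_name is the registry subject.
report = column_scanner.scan_external(
    source,
    database_name="not_used",
    table_name="my_subject",
    include_comment=True,
)
report.to_console()
```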
== Where to get it @@ -35,12 +36,12 @@ pip install 'metadata_guardian[all]' ``` ```sh -# Install with one data source -pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake]' +# Install with one metadata source in the list +pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,kafka_schema_registry]' ``` == Licence https://raw.githubusercontent.com/fvaleye/metadata-guardian/main/LICENSE.txt[Apache License 2.0] == Documentation -The documentation is hosted here: https://fvaleye.github.io/metadata-guardian/python/ \ No newline at end of file +The documentation is hosted here: https://fvaleye.github.io/metadata-guardian/python/ diff --git a/python/docs/source/installation.rst b/python/docs/source/installation.rst index 0133bb6..df04deb 100644 --- a/python/docs/source/installation.rst +++ b/python/docs/source/installation.rst @@ -8,5 +8,5 @@ Using Pip # Install all the metadata sources pip install 'metadata_guardian[all]' - # Install one metadata source in the list - pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,devel]' \ No newline at end of file + # Install with one metadata source in the list + pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,kafka_schema_registry]' \ No newline at end of file diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index 13a6d8f..a86326a 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -4,16 +4,15 @@ Usage Metadata Guardian ----------------- -Scan the column names of a local source: +**Workflow:** ->>> from metadata_guardian import DataRules, ColumnScanner, AvailableCategory ->>> from metadata_guardian.source import ParquetSource ->>> ->>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) ->>> source = ParquetSource("file.parquet") ->>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_local(source) ->>> report.to_console() +1. Create the Data Rules +2. Create the Metadata Source +3. Scan the Metadata Source +4. 
Analyze the reports + +Scan an external Metadata Source +-------------------------------- Scan the column names of a external source on a table: @@ -23,7 +22,8 @@ Scan the column names of a external source on a table: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_external(source, database_name="database_name", table_name="table_name", include_comment=True) +>>> with source: +>>> report = column_scanner.scan_external(source, database_name="database_name", table_name="table_name", include_comment=True) >>> report.to_console() Scan the column names of a external source on database: @@ -34,7 +34,8 @@ Scan the column names of a external source on database: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_external(source, database_name="database_name", include_comment=True) +>>> with source: +>>> report = column_scanner.scan_external(source, database_name="database_name", include_comment=True) >>> report.to_console() Scan the column names of an external source for a database asynchronously with asyncio: @@ -46,7 +47,23 @@ Scan the column names of an external source for a database asynchronously with a >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = asyncio.run(column_scanner.scan_external_async(source, database_name="database_name", include_comment=True)) +>>> with source: +>>> report = asyncio.run(column_scanner.scan_external_async(source, database_name="database_name", include_comment=True)) +>>> report.to_console() + + +Scan an internal Metadata Source +-------------------------------- + +Scan the column names of a local source: + +>>> from metadata_guardian import DataRules, ColumnScanner, AvailableCategory +>>> from metadata_guardian.source import ParquetSource +>>> +>>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) +>>> column_scanner = ColumnScanner(data_rules=data_rules) +>>> with ParquetSource("file.parquet") as source: +>>> report = column_scanner.scan_local(source) >>> report.to_console() Scan the column names of a local source: @@ -57,6 +74,7 @@ Scan the column names of a local source: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> column_scanner = ColumnScanner(data_rules=data_rules) >>> report = MetadataGuardianReport() +>>> paths = ["first_path", "second_path"] >>> for path in paths: >>> source = ParquetSource(path) >>> report.append(column_scanner.scan_local(source)) diff --git a/python/examples/README.md b/python/examples/README.md new file mode 100644 index 0000000..a50b9ea --- /dev/null +++ b/python/examples/README.md @@ -0,0 +1,3 @@ +Examples +This directory contains various examples of the Metadata Guardian features. +Make sure Metadata Guardian is installed and run the examples using the command line with python. 
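A minimal sketch of what running these example scripts could look like from a shell, assuming the package is installed with the relevant extras and the required connection settings are available in the environment; the URL, database name, and table name below are placeholders:

```sh
pip install 'metadata_guardian[all]'

# Scan every subject of a Kafka Schema Registry with the built-in PII rules
# (the URL is read from the environment; --database_name is required by the
# script but not used by this source).
export KAFKA_SCHEMA_REGISTRY_URL="http://localhost:8081"
python scan_external_sources_database.py --data-rules PII --external-source "Kafka Schema Registry" --database_name not_used

# Scan a single Snowflake table (connection settings come from the SNOWFLAKE_* environment variables).
python scan_external_sources_table.py --data-rules PII --external-source Snowflake --database_name MY_DATABASE --table_name MY_TABLE
```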
diff --git a/python/examples/scan_external_sources_custom_data_rules.py b/python/examples/scan_external_sources_custom_data_rules.py new file mode 100644 index 0000000..3b7565f --- /dev/null +++ b/python/examples/scan_external_sources_custom_data_rules.py @@ -0,0 +1,86 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules-path", + required=True, + help="The Data Rules specification yaml file path to use for creating the Data Rules", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules(path=args.data_rules_path) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_external_sources_database.py b/python/examples/scan_external_sources_database.py new file mode 100644 index 0000000..1a13670 --- /dev/null +++ b/python/examples/scan_external_sources_database.py @@ -0,0 +1,90 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + 
schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_external_sources_database_async.py b/python/examples/scan_external_sources_database_async.py new file mode 100644 index 0000000..30c1beb --- /dev/null +++ b/python/examples/scan_external_sources_database_async.py @@ -0,0 +1,93 @@ +import argparse +import asyncio +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + 
choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = asyncio.run( + column_scanner.scan_external_async( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + ) + report.to_console() diff --git a/python/examples/scan_external_sources_table.py b/python/examples/scan_external_sources_table.py new file mode 100644 index 0000000..8074d18 --- /dev/null +++ b/python/examples/scan_external_sources_table.py @@ -0,0 +1,92 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument("--table_name", required=True, help="The table name to scan") + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP 
BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + table_name=args.table_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_local_sources.py b/python/examples/scan_local_sources.py new file mode 100644 index 0000000..f4845e2 --- /dev/null +++ b/python/examples/scan_local_sources.py @@ -0,0 +1,59 @@ +import argparse +import os + +from metadata_guardian import AvailableCategory, ColumnScanner, DataRules +from metadata_guardian.source import AvroSource, ORCSource, ParquetSource + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--local-source", + choices=["Avro", "Parquet", "Orc"], + required=True, + help="The Local Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument("--path", required=True, help="The path of the file to scan") + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.local_source == "Avro": + source = AvroSource(local_path=args.path) + elif args.local_source == "Parquet": + source = ParquetSource(local_path=args.path) + elif args.local_source == "Orc": + source = ORCSource(local_path=args.path) + + report = column_scanner.scan_local(source) + report.to_console() diff --git a/python/metadata_guardian/report.py b/python/metadata_guardian/report.py index fca97f2..4b38e1b 100644 --- a/python/metadata_guardian/report.py +++ b/python/metadata_guardian/report.py @@ -108,7 +108,7 @@ def to_console(self) -> None: _table = Table( title=":magnifying_glass_tilted_right: Metadata Guardian report", show_header=True, - header_style="bold dim", + header_style="bold", show_lines=True, ) _table.add_column("Category", style="yellow", no_wrap=True) diff --git a/python/metadata_guardian/scanner.py b/python/metadata_guardian/scanner.py index ea7af86..83b6a7e 100644 --- a/python/metadata_guardian/scanner.py +++ b/python/metadata_guardian/scanner.py @@ -82,6 +82,11 @@ def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: f"[blue]Launch the metadata scanning of the local provider {source.type}" ) with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: + progression_bar.add_task_with_item( + item_name=source.local_path, + source_type=source.type, + total=1, + ) report = MetadataGuardianReport( report_results=[ ReportResults( diff --git 
a/python/metadata_guardian/source/external/aws_source.py b/python/metadata_guardian/source/external/aws_source.py index 0ae97f1..afbb26d 100644 --- a/python/metadata_guardian/source/external/aws_source.py +++ b/python/metadata_guardian/source/external/aws_source.py @@ -32,9 +32,9 @@ class AthenaSource(ExternalMetadataSource): aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get Athena connection. + Create Athena connection. :return: """ self.connection = boto3.client( @@ -44,6 +44,9 @@ def get_connection(self) -> None: aws_secret_access_key=self.aws_secret_access_key, ) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -56,7 +59,7 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() response = self.connection.get_table_metadata( CatalogName=self.catalog_name, DatabaseName=database_name, @@ -83,7 +86,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection: - self.get_connection() + self.create_connection() table_names_list = list() response = self.connection.list_table_metadata( CatalogName=self.catalog_name, @@ -122,9 +125,9 @@ class GlueSource(ExternalMetadataSource): aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the Glue connection + Create the Glue connection :return: """ self.connection = boto3.client( @@ -134,6 +137,9 @@ def get_connection(self) -> None: aws_secret_access_key=self.aws_secret_access_key, ) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -146,7 +152,7 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() response = self.connection.get_table( DatabaseName=database_name, Name=table_name ) @@ -171,7 +177,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection: - self.get_connection() + self.create_connection() table_names_list = list() response = self.connection.get_tables( DatabaseName=database_name, diff --git a/python/metadata_guardian/source/external/deltatable_source.py b/python/metadata_guardian/source/external/deltatable_source.py index bb4240a..264b7c3 100644 --- a/python/metadata_guardian/source/external/deltatable_source.py +++ b/python/metadata_guardian/source/external/deltatable_source.py @@ -22,14 +22,18 @@ class DeltaTableSource(ExternalMetadataSource): uri: str data_catalog: DataCatalog = DataCatalog.AWS + external_data_catalog_disable: bool = True - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the DeltaTable instance. + Create the DeltaTable instance. 
:return: """ self.connection = DeltaTable(self.uri) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: Optional[str] = None, @@ -44,14 +48,18 @@ def get_column_names( :return: the list of the column names """ try: - if database_name and table_name: + if ( + not self.external_data_catalog_disable + and database_name + and table_name + ): self.connection = DeltaTable.from_data_catalog( data_catalog=self.data_catalog, database_name=database_name, table_name=table_name, ) elif not self.connection: - self.get_connection() + self.create_connection() schema = self.connection.schema() columns = list() for field in schema.fields: diff --git a/python/metadata_guardian/source/external/external_metadata_source.py b/python/metadata_guardian/source/external/external_metadata_source.py index ac22245..a72deb1 100644 --- a/python/metadata_guardian/source/external/external_metadata_source.py +++ b/python/metadata_guardian/source/external/external_metadata_source.py @@ -1,6 +1,8 @@ from abc import abstractmethod from typing import Any, List, Optional +from loguru import logger + from ...exceptions import MetadataGuardianException from ..metadata_source import MetadataSource @@ -10,6 +12,26 @@ class ExternalMetadataSource(MetadataSource): connection: Optional[Any] = None + def __enter__(self) -> "ExternalMetadataSource": + try: + self.create_connection() + except Exception as exception: + logger.exception( + "Error raised while opening the Metadata Source connection" + ) + raise exception + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore + try: + self.close_connection() + except Exception as exception: + logger.exception( + "Error raised while closing the Metadata Source connection" + ) + raise exception + return self + @abstractmethod def get_column_names( self, @@ -36,9 +58,17 @@ def get_table_names_list(self, database_name: str) -> List[str]: pass @abstractmethod - def get_connection(self) -> None: + def create_connection(self) -> None: + """ + Create the connection of the source. + :return: + """ + pass + + @abstractmethod + def close_connection(self) -> None: """ - Get the connection of the source. + Close the connection of the source. :return: """ pass diff --git a/python/metadata_guardian/source/external/gcp_source.py b/python/metadata_guardian/source/external/gcp_source.py index 3954eab..1624888 100644 --- a/python/metadata_guardian/source/external/gcp_source.py +++ b/python/metadata_guardian/source/external/gcp_source.py @@ -26,7 +26,7 @@ class BigQuerySource(ExternalMetadataSource): project: Optional[str] = None location: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ Get the Big Query connection. 
:return: @@ -41,6 +41,13 @@ def get_connection(self) -> None: logger.exception("Error when connecting to BigQuery") raise exception + def close_connection(self) -> None: + """ + Close the BigQuery connection + :return: + """ + self.connection.close() + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -54,16 +61,17 @@ def get_column_names( try: if not self.connection: - self.get_connection() - query_job = self.connection.query( - f'SELECT column_name, description FROM `{database_name}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` WHERE table_name = "{table_name}"' - ) - results = query_job.result() + self.create_connection() + + table_reference = self.connection.dataset( + database_name, project=self.project + ).table(table_name) + table = self.connection.get_table(table_reference) columns = list() - for row in results: - columns.append(row.column_name.lower()) - if include_comment and row.description: - columns.append(row.description.lower()) + for column in table.schema: + columns.append(column.name.lower()) + if include_comment and column.description: + columns.append(column.description.lower()) return columns except Exception as exception: logger.exception( @@ -80,7 +88,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: try: if not self.connection: - self.get_connection() + self.create_connection() query_job = self.connection.query( f"SELECT table_name FROM `{database_name}.INFORMATION_SCHEMA.TABLES`" ) diff --git a/python/metadata_guardian/source/external/kafka_schema_registry_source.py b/python/metadata_guardian/source/external/kafka_schema_registry_source.py index feb3d1d..87ee144 100644 --- a/python/metadata_guardian/source/external/kafka_schema_registry_source.py +++ b/python/metadata_guardian/source/external/kafka_schema_registry_source.py @@ -32,15 +32,15 @@ class KafkaSchemaRegistrySource(ExternalMetadataSource): url: str ssl_certificate_location: Optional[str] = None ssl_key_location: Optional[str] = None - connection: Optional[Any] = None + connection: Optional[SchemaRegistryClient] = None authenticator: Optional[ KafkaSchemaRegistryAuthentication ] = KafkaSchemaRegistryAuthentication.USER_PWD comment_field_name: str = "doc" - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the connection of the Kafka Schema Registry. + Create the connection of the Kafka Schema Registry. :return: """ if self.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD: @@ -52,6 +52,13 @@ def get_connection(self) -> None: else: raise NotImplementedError() + def close_connection(self) -> None: + """ + Close the Kafka Schema Registry connection. 
+ :return: + """ + self.connection.__exit__() + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -64,7 +71,7 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() registered_schema = self.connection.get_latest_version(table_name) columns = list() for field in json.loads(registered_schema.schema.schema_str)["fields"]: @@ -86,7 +93,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection: - self.get_connection() + self.create_connection() all_subjects = self.connection.get_subjects() return all_subjects except Exception as exception: diff --git a/python/metadata_guardian/source/external/snowflake_source.py b/python/metadata_guardian/source/external/snowflake_source.py index cb86bb0..a8221ea 100644 --- a/python/metadata_guardian/source/external/snowflake_source.py +++ b/python/metadata_guardian/source/external/snowflake_source.py @@ -43,9 +43,9 @@ class SnowflakeSource(ExternalMetadataSource): oauth_host: Optional[str] = None authenticator: SnowflakeAuthenticator = SnowflakeAuthenticator.USER_PWD - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get a Snowflake connection based on the SnowflakeAuthenticator. + Create a Snowflake connection based on the SnowflakeAuthenticator. :return: """ if self.authenticator == SnowflakeAuthenticator.USER_PWD: @@ -76,6 +76,13 @@ def get_connection(self) -> None: converter_class=SnowflakeNoConverterToPython, ) + def close_connection(self) -> None: + """ + Close the Snowflake connection. + :return: + """ + self.connection.close() + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -88,7 +95,7 @@ def get_column_names( """ try: if not self.connection or self.connection.is_closed(): - self.get_connection() + self.create_connection() cursor = self.connection.cursor() cursor.execute( f'SHOW COLUMNS IN "{database_name}"."{self.schema_name}"."{table_name}"' @@ -104,7 +111,7 @@ def get_column_names( return columns except Exception as exception: logger.exception( - f"Error in getting columns name from Snowflake {self.schema_name}.{database_name}.{table_name}" + f"Error in getting columns name from Snowflake {database_name}.{self.schema_name}.{table_name}" ) raise exception finally: @@ -118,7 +125,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection or self.connection.is_closed(): - self.get_connection() + self.create_connection() cursor = self.connection.cursor() cursor.execute(f'SHOW TABLES IN DATABASE "{database_name}"') rows = cursor.fetchall() diff --git a/python/tests/external/test_deltatable_source.py b/python/tests/external/test_deltatable_source.py index d97da5d..77bc6ad 100644 --- a/python/tests/external/test_deltatable_source.py +++ b/python/tests/external/test_deltatable_source.py @@ -45,6 +45,7 @@ def test_deltatable_source_get_column_names_from_database_and_table(mock_connect uri = "s3://test_table" database_name = "database_name" table_name = "table_name" + external_data_catalog_disable = False schema = Schema( fields=[ Field( @@ -71,7 +72,9 @@ def test_deltatable_source_get_column_names_from_database_and_table(mock_connect "{'comment': 'comment2'}", ] - column_names = DeltaTableSource(uri=uri).get_column_names( + column_names = DeltaTableSource( + uri=uri, external_data_catalog_disable=external_data_catalog_disable + ).get_column_names( 
database_name=database_name, table_name=table_name, include_comment=True ) diff --git a/python/tests/external/test_gcp_source.py b/python/tests/external/test_gcp_source.py index c676424..907bcc4 100644 --- a/python/tests/external/test_gcp_source.py +++ b/python/tests/external/test_gcp_source.py @@ -1,6 +1,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from google.cloud import bigquery + from metadata_guardian.source import BigQuerySource @@ -9,14 +11,18 @@ def test_big_query_source_get_column_names(mock_connection): service_account_json_path = "" dataset_name = "test_dataset" table_name = "test_table" - results = [ - SimpleNamespace(column_name="timestamp", description="description1"), - SimpleNamespace(column_name="address_id", description="description2"), - ] + results = SimpleNamespace( + schema=[ + bigquery.SchemaField( + "timestamp", "STRING", mode="REQUIRED", description="description1" + ), + bigquery.SchemaField( + "address_id", "STRING", mode="REQUIRED", description="description2" + ), + ] + ) mock_connection.return_value = mock_connection - response = Mock() - response.result.return_value = results - mock_connection.query.return_value = response + mock_connection.get_table.return_value = results expected = ["timestamp", "description1", "address_id", "description2"] column_names = BigQuerySource(