From c6d2ebe352073af809d23b54c524863b4562a789 Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Thu, 25 May 2023 09:33:35 +0200 Subject: [PATCH 1/7] initial changes for nested type support --- discoverx/scanner.py | 76 +++++++++++++++++++++++++++++++- tests/unit/conftest.py | 36 ++++++++++++++- tests/unit/data/columns_mock.csv | 3 ++ tests/unit/data/tb_2.json | 3 ++ tests/unit/scanner_test.py | 2 +- 5 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 tests/unit/data/tb_2.json diff --git a/discoverx/scanner.py b/discoverx/scanner.py index caf177e..8b87d94 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -1,6 +1,8 @@ from dataclasses import dataclass import pandas as pd from pyspark.sql import SparkSession +from pyspark.sql.types import * +from pyspark.sql.types import _parse_datatype_string from typing import Optional, List, Set from discoverx.common.helper import strip_margin, format_regex @@ -68,6 +70,7 @@ def n_scanned_columns(self) -> int: class Scanner: COLUMNS_TABLE_NAME = "system.information_schema.columns" + COMPLEX_TYPES = {StructType, ArrayType, MapType} def __init__( self, @@ -92,6 +95,7 @@ def __init__( self.content: ScanContent = self._resolve_scan_content() self.rule_list = self.rules.get_rules(rule_filter=self.rules_filter) self.scan_result: Optional[ScanResult] = None + self.column_list = [] def _get_list_of_tables(self) -> List[TableInfo]: table_list_sql = self._get_table_list_sql() @@ -208,6 +212,36 @@ def scan(self): else: self.scan_result = ScanResult(df=pd.DataFrame()) + @staticmethod + def backtick_col_name(col_name: str) -> str: + col_name_splitted = col_name.split(".") + return ".".join(["`" + col + "`" for col in col_name_splitted]) + + def recursive_flatten_complex_type(self, col_name, schema): + if type(schema) in self.COMPLEX_TYPES: + iterable = schema + elif type(schema) is StructField: + iterable = schema.dataType + elif schema == StringType(): + self.column_list.append({"name": col_name, "type": "string"}) + return + else: + return + + if type(iterable) is StructType: + for field in iterable: + if type(field.dataType) == StringType: + self.column_list.append({"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"}) + elif type(field.dataType) in self.COMPLEX_TYPES: + self.recursive_flatten_complex_type(col_name + "." + field.name, field) + elif type(iterable) is MapType: + if type(iterable.valueType) not in self.COMPLEX_TYPES: + self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"}) + if type(iterable.keyType) not in self.COMPLEX_TYPES: + self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"}) + elif type(iterable) is ArrayType: + self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"}) + def _rule_matching_sql(self, table_info: TableInfo): """ Given a table and a set of rules this method will return a @@ -223,7 +257,47 @@ def _rule_matching_sql(self, table_info: TableInfo): """ expressions = [r for r in self.rule_list if r.type == RuleTypes.REGEX] - cols = [c for c in table_info.columns if c.data_type.lower() == "string"] + expr_pdf = pd.DataFrame([{"rule_name": r.name, "rule_definition": r.definition, "key": 0} for r in expressions]) + for col in table_info.columns: + self.recursive_flatten_complex_type(col.name, _parse_datatype_string(col.data_type)) + columns_pdf = pd.DataFrame(self.column_list) + # prepare for cross-join + columns_pdf["key"] = 0 + # cross-join + col_expr_pdf = columns_pdf.merge(expr_pdf, on=["key"]) + + def sum_expressions(row): + if row.type == "string": + return f"int(regexp_like({row.col_name}, '{row.rule_definition}'))" + elif row.type == "array": + return f"size(filter({row.col_name}, x -> x rlike '{row.rule_definition}'))" + elif row.type == "map_values": + return f"size(filter(map_values({row.col_name}), x -> x rlike '{row.rule_definition}'))" + elif row.type == "map_keys": + return f"size(filter(map_keys({row.col_name}), x -> x rlike '{row.rule_definition}'))" + else: + return None + + def count_expressions(row): + if row.type == "string": + return "1" + elif row.type == "array": + return f"size({row.col_name})" + elif row.type == "map_values": + return f"size(map_values({row.col_name}))" + elif row.type == "map_keys": + return f"size(map_keys({row.col_name}))" + else: + return None + + col_expr_pdf["sum_expression"] = col_expr_pdf.apply(sum_expressions, axis=1) + col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) + # build stack expression + stack_expression = list(zip(col_expr_pdf.col_name, col_expr_pdf.rule_name, col_expr_pdf.sum_expression, col_expr_pdf.count_expression)) + stack_expression = [item for sublist in stack_expression for item in sublist] + stack_expr_string = f"stack(4, + {', '.join(stack_expression)}) as (column, rule_name, sum_value, count_value)" + + #cols = [c for c in table_info.columns if c.data_type.lower() == "string"] if not cols: raise Exception( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 86000dd..aac86f4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,13 +8,12 @@ import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Iterator -from unittest.mock import patch import mlflow import pytest from delta import configure_spark_with_delta_pip from pyspark.sql import SparkSession +from pyspark.sql.types import * from discoverx.classification import DeltaTable from discoverx.dx import Classifier from discoverx.classification import func @@ -139,6 +138,39 @@ def sample_datasets(spark: SparkSession, request): ).createOrReplaceTempView("view_tb_1") spark.sql(f"CREATE TABLE IF NOT EXISTS default.tb_1 USING delta LOCATION '{warehouse_dir}/tb_1' AS SELECT * FROM view_tb_1 ") + # tb_2 + test_file_tb2_path = module_path.parent / "data/tb_2.json" + schema_json_example = ( + StructType() + .add( + "customer", + StructType() + .add("name", StringType(), True) + .add("id", IntegerType(), True) + .add( + "contact", + StructType() + .add( + "address", + StructType() + .add("street", StringType(), True) + .add("town", StringType(), True) + .add("postal_number", StringType(), True) + .add("country", StringType(), True), + True, + ) + .add("email", StringType()), + ) + .add("products_owned", ArrayType(StringType()), True) + .add("interactions", MapType(StringType(), StringType())), + True, + ) + .add("active", BooleanType(), True) + .add("categories", MapType(StringType(), StringType())) +) + spark.read.schema(schema_json_example).json(str(test_file_tb2_path.resolve())).createOrReplaceTempView("view_tb_2") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.tb_1 USING delta LOCATION '{warehouse_dir}/tb_2' AS SELECT * FROM view_tb_2 ") # columns_mock test_file_path = module_path.parent / "data/columns_mock.csv" (spark diff --git a/tests/unit/data/columns_mock.csv b/tests/unit/data/columns_mock.csv index b4d5dbe..84fd8d4 100644 --- a/tests/unit/data/columns_mock.csv +++ b/tests/unit/data/columns_mock.csv @@ -23,3 +23,6 @@ hive_metastore,default,tb_all_types,str_part_col,STRING,1 ,default,tb_1,ip,STRING, ,default,tb_1,mac,STRING, ,default,tb_1,description,STRING, +,default,tb_2,active,BOOLEAN, +,default,tb_2,categories,"map", +,default,tb_2,customer,"struct,email:string>,products_owned:array>", diff --git a/tests/unit/data/tb_2.json b/tests/unit/data/tb_2.json new file mode 100644 index 0000000..98f6bf2 --- /dev/null +++ b/tests/unit/data/tb_2.json @@ -0,0 +1,3 @@ +{"customer": {"name": "AAA BBBB", "id": 1, "contact": {"address": {"street": "AAA street 11", "town": "AAA town", "postal_number": "111333", "country": "AAA country"}, "email": "aaa.bbb@aaa.com"}, "products_owned": ["product1", "product2", "product10"], "interactions": {"service": "test aaa", "shop": "test shop aaa"}}, "active": true, "categories": {"cat1": "D"}} +{"customer": {"name": "BBB CCCC", "id": 2, "contact": {"address": {"street": "BBB street 12", "town": "BBB town", "postal_number": "111233", "country": "BBB country"}, "email": "bbb.ccc@bbb.com"}, "products_owned": ["product1", "product10"], "interactions": {"service": "test bbb", "request": "test r bbb"}}, "active": false, "categories": {"cat1": "A", "cat2": "B", "cat3": "C"}} +{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country"}, "email": "ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index 73d660b..a2d14d5 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -207,7 +207,7 @@ def test_scan(spark: SparkSession): rules = Rules() MockedScanner = Scanner MockedScanner.COLUMNS_TABLE_NAME = "default.columns_mock" - scanner = MockedScanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*") + scanner = MockedScanner(spark, rules=rules, tables="tb_*", rule_filter="ip_*") scanner.scan() assert scanner.scan_result.df.equals(expected) From 5d3deb1aebce8a99e5e3a0cdc49101997c86c7ab Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Fri, 26 May 2023 10:07:34 +0200 Subject: [PATCH 2/7] working solution for strings in structs --- discoverx/scanner.py | 70 ++++++++++++++++++++---------------------- tests/unit/conftest.py | 2 +- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/discoverx/scanner.py b/discoverx/scanner.py index 8b87d94..8e83757 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -261,43 +261,39 @@ def _rule_matching_sql(self, table_info: TableInfo): for col in table_info.columns: self.recursive_flatten_complex_type(col.name, _parse_datatype_string(col.data_type)) columns_pdf = pd.DataFrame(self.column_list) - # prepare for cross-join - columns_pdf["key"] = 0 - # cross-join - col_expr_pdf = columns_pdf.merge(expr_pdf, on=["key"]) - - def sum_expressions(row): - if row.type == "string": - return f"int(regexp_like({row.col_name}, '{row.rule_definition}'))" - elif row.type == "array": - return f"size(filter({row.col_name}, x -> x rlike '{row.rule_definition}'))" - elif row.type == "map_values": - return f"size(filter(map_values({row.col_name}), x -> x rlike '{row.rule_definition}'))" - elif row.type == "map_keys": - return f"size(filter(map_keys({row.col_name}), x -> x rlike '{row.rule_definition}'))" - else: - return None - - def count_expressions(row): - if row.type == "string": - return "1" - elif row.type == "array": - return f"size({row.col_name})" - elif row.type == "map_values": - return f"size(map_values({row.col_name}))" - elif row.type == "map_keys": - return f"size(map_keys({row.col_name}))" - else: - return None - - col_expr_pdf["sum_expression"] = col_expr_pdf.apply(sum_expressions, axis=1) - col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) - # build stack expression - stack_expression = list(zip(col_expr_pdf.col_name, col_expr_pdf.rule_name, col_expr_pdf.sum_expression, col_expr_pdf.count_expression)) - stack_expression = [item for sublist in stack_expression for item in sublist] - stack_expr_string = f"stack(4, + {', '.join(stack_expression)}) as (column, rule_name, sum_value, count_value)" - + # # prepare for cross-join + # columns_pdf["key"] = 0 + # # cross-join + # col_expr_pdf = columns_pdf.merge(expr_pdf, on=["key"]) + # + # def sum_expressions(row): + # if row.type == "string": + # return f"int(regexp_like({row.col_name}, '{row.rule_definition}'))" + # elif row.type == "array": + # return f"size(filter({row.col_name}, x -> x rlike '{row.rule_definition}'))" + # elif row.type == "map_values": + # return f"size(filter(map_values({row.col_name}), x -> x rlike '{row.rule_definition}'))" + # elif row.type == "map_keys": + # return f"size(filter(map_keys({row.col_name}), x -> x rlike '{row.rule_definition}'))" + # else: + # return None + # + # def count_expressions(row): + # if row.type == "string": + # return "1" + # elif row.type == "array": + # return f"size({row.col_name})" + # elif row.type == "map_values": + # return f"size(map_values({row.col_name}))" + # elif row.type == "map_keys": + # return f"size(map_keys({row.col_name}))" + # else: + # return None + # + # col_expr_pdf["sum_expression"] = col_expr_pdf.apply(sum_expressions, axis=1) + # col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) #cols = [c for c in table_info.columns if c.data_type.lower() == "string"] + cols = columns_pdf.loc[columns_pdf.type == "string", "col_name"].to_list() if not cols: raise Exception( @@ -317,7 +313,7 @@ def count_expressions(row): unpivot_expressions = ", ".join( [f"'{r.name}', `{r.name}`" for r in expressions] ) - unpivot_columns = ", ".join([f"'{c.name}', `{c.name}`" for c in cols]) + unpivot_columns = ", ".join([f"'{c}', {c}" for c in cols]) sql = f""" SELECT diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index aac86f4..116961e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -170,7 +170,7 @@ def sample_datasets(spark: SparkSession, request): ) spark.read.schema(schema_json_example).json(str(test_file_tb2_path.resolve())).createOrReplaceTempView("view_tb_2") spark.sql( - f"CREATE TABLE IF NOT EXISTS default.tb_1 USING delta LOCATION '{warehouse_dir}/tb_2' AS SELECT * FROM view_tb_2 ") + f"CREATE TABLE IF NOT EXISTS default.tb_2 USING delta LOCATION '{warehouse_dir}/tb_2' AS SELECT * FROM view_tb_2 ") # columns_mock test_file_path = module_path.parent / "data/columns_mock.csv" (spark From dd0dafe168cdd33bef120480e249c03b9c6deb4a Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Wed, 31 May 2023 11:07:18 +0200 Subject: [PATCH 3/7] working solution for strings in structs --- discoverx/scanner.py | 129 ++++++++++++++++++------------------- tests/unit/scanner_test.py | 12 ++++ 2 files changed, 76 insertions(+), 65 deletions(-) diff --git a/discoverx/scanner.py b/discoverx/scanner.py index 8e83757..185bb37 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -3,6 +3,7 @@ from pyspark.sql import SparkSession from pyspark.sql.types import * from pyspark.sql.types import _parse_datatype_string +from pyspark.sql.utils import ParseException from typing import Optional, List, Set from discoverx.common.helper import strip_margin, format_regex @@ -62,9 +63,7 @@ class ScanResult: @property def n_scanned_columns(self) -> int: - return len( - self.df[["catalog", "database", "table", "column"]].drop_duplicates() - ) + return len(self.df[["catalog", "database", "table", "column"]].drop_duplicates()) class Scanner: @@ -95,7 +94,6 @@ def __init__( self.content: ScanContent = self._resolve_scan_content() self.rule_list = self.rules.get_rules(rule_filter=self.rules_filter) self.scan_result: Optional[ScanResult] = None - self.column_list = [] def _get_list_of_tables(self) -> List[TableInfo]: table_list_sql = self._get_table_list_sql() @@ -107,9 +105,7 @@ def _get_list_of_tables(self) -> List[TableInfo]: row["table_schema"], row["table_name"], [ - ColumnInfo( - col["column_name"], col["data_type"], col["partition_index"], [] - ) + ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], []) for col in row["table_columns"] ], ) @@ -128,9 +124,7 @@ def _get_table_list_sql(self): catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")""" database_sql = f"""AND regexp_like(table_schema, "^{self.databases.replace("*", ".*")}$")""" - table_sql = ( - f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")""" - ) + table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")""" sql = f""" SELECT @@ -158,9 +152,7 @@ def _resolve_scan_content(self) -> ScanContent: def scan(self): - logger.friendly( - """Ok, I'm going to scan your lakehouse for data that matches your rules.""" - ) + logger.friendly("""Ok, I'm going to scan your lakehouse for data that matches your rules.""") text = f""" This is what you asked for: @@ -200,15 +192,13 @@ def scan(self): # Execute SQL and append result dfs.append(self.spark.sql(sql).toPandas()) except Exception as e: - logger.error( - f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}" - ) + logger.error(f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}") continue logger.debug("Finished lakehouse scanning task") if dfs: - self.scan_result = ScanResult(df=pd.concat(dfs)) + self.scan_result = ScanResult(df=pd.concat(dfs).reset_index(drop=True)) else: self.scan_result = ScanResult(df=pd.DataFrame()) @@ -217,30 +207,34 @@ def backtick_col_name(col_name: str) -> str: col_name_splitted = col_name.split(".") return ".".join(["`" + col + "`" for col in col_name_splitted]) - def recursive_flatten_complex_type(self, col_name, schema): + def recursive_flatten_complex_type(self, col_name, schema, column_list): if type(schema) in self.COMPLEX_TYPES: iterable = schema elif type(schema) is StructField: iterable = schema.dataType elif schema == StringType(): - self.column_list.append({"name": col_name, "type": "string"}) - return + column_list.append({"col_name": col_name, "type": "string"}) + return column_list else: - return + return column_list if type(iterable) is StructType: for field in iterable: if type(field.dataType) == StringType: - self.column_list.append({"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"}) + column_list.append( + {"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"} + ) elif type(field.dataType) in self.COMPLEX_TYPES: - self.recursive_flatten_complex_type(col_name + "." + field.name, field) + column_list = self.recursive_flatten_complex_type(col_name + "." + field.name, field, column_list) elif type(iterable) is MapType: if type(iterable.valueType) not in self.COMPLEX_TYPES: - self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"}) + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"}) if type(iterable.keyType) not in self.COMPLEX_TYPES: - self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"}) + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"}) elif type(iterable) is ArrayType: - self.column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"}) + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"}) + + return column_list def _rule_matching_sql(self, table_info: TableInfo): """ @@ -258,9 +252,16 @@ def _rule_matching_sql(self, table_info: TableInfo): expressions = [r for r in self.rule_list if r.type == RuleTypes.REGEX] expr_pdf = pd.DataFrame([{"rule_name": r.name, "rule_definition": r.definition, "key": 0} for r in expressions]) + column_list = [] for col in table_info.columns: - self.recursive_flatten_complex_type(col.name, _parse_datatype_string(col.data_type)) - columns_pdf = pd.DataFrame(self.column_list) + try: + data_type = _parse_datatype_string(col.data_type) + except ParseException: + data_type = None + + if data_type: + self.recursive_flatten_complex_type(col.name, data_type, column_list) + columns_pdf = pd.DataFrame(column_list) # # prepare for cross-join # columns_pdf["key"] = 0 # # cross-join @@ -292,54 +293,52 @@ def _rule_matching_sql(self, table_info: TableInfo): # # col_expr_pdf["sum_expression"] = col_expr_pdf.apply(sum_expressions, axis=1) # col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) - #cols = [c for c in table_info.columns if c.data_type.lower() == "string"] - cols = columns_pdf.loc[columns_pdf.type == "string", "col_name"].to_list() - - if not cols: - raise Exception( - f"There are no columns of type string to be scanned in {table_info.table}" - ) + # cols = [c for c in table_info.columns if c.data_type.lower() == "string"] + if len(columns_pdf) == 0: + raise Exception(f"There are no columns of type string to be scanned in {table_info.table}") if not expressions: raise Exception(f"There are no rules to scan for.") + string_cols = columns_pdf.loc[columns_pdf.type == "string", "col_name"].to_list() + all_sql = self.string_col_sql(string_cols, expressions, table_info) + + return all_sql + + def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str: catalog_str = f"{table_info.catalog}." if table_info.catalog else "" matching_columns = [ - f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" - for r in expressions + f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" for r in expressions ] matching_string = ",\n ".join(matching_columns) - unpivot_expressions = ", ".join( - [f"'{r.name}', `{r.name}`" for r in expressions] - ) + unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}`" for r in expressions]) unpivot_columns = ", ".join([f"'{c}', {c}" for c in cols]) sql = f""" - SELECT - '{table_info.catalog}' as catalog, - '{table_info.database}' as database, - '{table_info.table}' as table, - column, - rule_name, - (sum(value) / count(value)) as frequency - FROM - ( - SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) - FROM - ( - SELECT - column, - {matching_string} - FROM ( - SELECT - stack({len(cols)}, {unpivot_columns}) AS (column, value) - FROM {catalog_str}{table_info.database}.{table_info.table} - TABLESAMPLE ({self.sample_size} ROWS) + SELECT + '{table_info.catalog}' as catalog, + '{table_info.database}' as database, + '{table_info.table}' as table, + column, + rule_name, + (sum(value) / count(value)) as frequency + FROM + ( + SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) + FROM + ( + SELECT + column, + {matching_string} + FROM ( + SELECT + stack({len(cols)}, {unpivot_columns}) AS (column, value) + FROM {catalog_str}{table_info.database}.{table_info.table} + TABLESAMPLE ({self.sample_size} ROWS) + ) + ) ) - ) - ) - GROUP BY catalog, database, table, column, rule_name - """ - + GROUP BY catalog, database, table, column, rule_name + """ return strip_margin(sql) diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index a2d14d5..4267e7a 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -200,6 +200,18 @@ def test_scan(spark: SparkSession): ["None", "default", "tb_1", "mac", "ip_v6", 0.0], ["None", "default", "tb_1", "description", "ip_v4", 0.0], ["None", "default", "tb_1", "description", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v6", 0.0], ], columns=["catalog", "database", "table", "column", "rule_name", "frequency"], ) From 264210caa0db9a6d3feb9dba6beceb78e43a35d8 Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Wed, 31 May 2023 15:16:36 +0200 Subject: [PATCH 4/7] scan support for array-type --- discoverx/scanner.py | 56 ++++++++++++++++++++++++++++++-- tests/unit/conftest.py | 3 +- tests/unit/data/columns_mock.csv | 2 +- tests/unit/data/tb_2.json | 6 ++-- tests/unit/scanner_test.py | 4 +++ 5 files changed, 63 insertions(+), 8 deletions(-) diff --git a/discoverx/scanner.py b/discoverx/scanner.py index 185bb37..f5e52ce 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -69,7 +69,7 @@ def n_scanned_columns(self) -> int: class Scanner: COLUMNS_TABLE_NAME = "system.information_schema.columns" - COMPLEX_TYPES = {StructType, ArrayType, MapType} + COMPLEX_TYPES = {StructType, ArrayType} def __init__( self, @@ -295,14 +295,22 @@ def _rule_matching_sql(self, table_info: TableInfo): # col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) # cols = [c for c in table_info.columns if c.data_type.lower() == "string"] if len(columns_pdf) == 0: - raise Exception(f"There are no columns of type string to be scanned in {table_info.table}") + raise Exception(f"There are no columns with supported types to be scanned in {table_info.table}") if not expressions: raise Exception(f"There are no rules to scan for.") string_cols = columns_pdf.loc[columns_pdf.type == "string", "col_name"].to_list() - all_sql = self.string_col_sql(string_cols, expressions, table_info) + sql_list = [] + if len(string_cols) > 0: + sql_list.append(self.string_col_sql(string_cols, expressions, table_info)) + + array_cols = columns_pdf.loc[columns_pdf.type == "array", "col_name"].to_list() + if len(array_cols) > 0: + sql_list.append(self.array_col_sql(array_cols, expressions, table_info)) + + all_sql = "\nUNION ALL \n".join(sql_list) return all_sql def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str: @@ -342,3 +350,45 @@ def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) - GROUP BY catalog, database, table, column, rule_name """ return strip_margin(sql) + + def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str: + catalog_str = f"{table_info.catalog}." if table_info.catalog else "" + matching_columns_sum = [ + f"size(filter(value, x -> x rlike '{r.definition}')) AS `{r.name}_sum`" for r in expressions + ] + matching_columns_count = [ + f"size(value) AS `{r.name}_count`" for r in expressions + ] + matching_columns = matching_columns_sum + matching_columns_count + matching_string = ",\n ".join(matching_columns) + + unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}_sum`, `{r.name}_count`" for r in expressions]) + unpivot_columns = ", ".join([f"'{c}', {c}" for c in cols]) + + sql = f""" + SELECT + '{table_info.catalog}' as catalog, + '{table_info.database}' as database, + '{table_info.table}' as table, + column, + rule_name, + (sum(value_sum) / sum(value_count)) as frequency + FROM + ( + SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count) + FROM + ( + SELECT + column, + {matching_string} + FROM ( + SELECT + stack({len(cols)}, {unpivot_columns}) AS (column, value) + FROM {catalog_str}{table_info.database}.{table_info.table} + TABLESAMPLE ({self.sample_size} ROWS) + ) + ) + ) + GROUP BY catalog, database, table, column, rule_name + """ + return strip_margin(sql) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 116961e..559bcf9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -156,7 +156,8 @@ def sample_datasets(spark: SparkSession, request): .add("street", StringType(), True) .add("town", StringType(), True) .add("postal_number", StringType(), True) - .add("country", StringType(), True), + .add("country", StringType(), True) + .add("ips_used", ArrayType(StringType()), True), True, ) .add("email", StringType()), diff --git a/tests/unit/data/columns_mock.csv b/tests/unit/data/columns_mock.csv index 84fd8d4..6980d41 100644 --- a/tests/unit/data/columns_mock.csv +++ b/tests/unit/data/columns_mock.csv @@ -25,4 +25,4 @@ hive_metastore,default,tb_all_types,str_part_col,STRING,1 ,default,tb_1,description,STRING, ,default,tb_2,active,BOOLEAN, ,default,tb_2,categories,"map", -,default,tb_2,customer,"struct,email:string>,products_owned:array>", +,default,tb_2,customer,"struct>,email:string>,products_owned:array>", diff --git a/tests/unit/data/tb_2.json b/tests/unit/data/tb_2.json index 98f6bf2..7d7222e 100644 --- a/tests/unit/data/tb_2.json +++ b/tests/unit/data/tb_2.json @@ -1,3 +1,3 @@ -{"customer": {"name": "AAA BBBB", "id": 1, "contact": {"address": {"street": "AAA street 11", "town": "AAA town", "postal_number": "111333", "country": "AAA country"}, "email": "aaa.bbb@aaa.com"}, "products_owned": ["product1", "product2", "product10"], "interactions": {"service": "test aaa", "shop": "test shop aaa"}}, "active": true, "categories": {"cat1": "D"}} -{"customer": {"name": "BBB CCCC", "id": 2, "contact": {"address": {"street": "BBB street 12", "town": "BBB town", "postal_number": "111233", "country": "BBB country"}, "email": "bbb.ccc@bbb.com"}, "products_owned": ["product1", "product10"], "interactions": {"service": "test bbb", "request": "test r bbb"}}, "active": false, "categories": {"cat1": "A", "cat2": "B", "cat3": "C"}} -{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country"}, "email": "ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file +{"customer": {"name": "AAA BBBB", "id": 1, "contact": {"address": {"street": "AAA street 11", "town": "AAA town", "postal_number": "111333", "country": "AAA country", "ips_used": []}, "email": "aaa.bbb@aaa.com"}, "products_owned": ["product1", "product2", "product10"], "interactions": {"service": "test aaa", "shop": "test shop aaa"}}, "active": true, "categories": {"cat1": "D"}} +{"customer": {"name": "BBB CCCC", "id": 2, "contact": {"address": {"street": "BBB street 12", "town": "BBB town", "postal_number": "111233", "country": "BBB country", "ips_used": ["102.2.1.1", "103.3.1.1"]}, "email": "bbb.ccc@bbb.com"}, "products_owned": ["product1", "product10"], "interactions": {"service": "test bbb", "request": "test r bbb"}}, "active": false, "categories": {"cat1": "A", "cat2": "B", "cat3": "C"}} +{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country", "ips_used": ["102.1.1.1", "103.1.1.1", "104.1.1.1"]}, "email": "ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index 4267e7a..ac4cd77 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -212,6 +212,10 @@ def test_scan(spark: SparkSession): ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "ip_v6", 0.0], ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v4", 0.0], ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "ip_v4", 1.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "ip_v6", 0.0], ], columns=["catalog", "database", "table", "column", "rule_name", "frequency"], ) From 7739396e259b54122afb63906155a92bcddf2158 Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Wed, 31 May 2023 16:32:46 +0200 Subject: [PATCH 5/7] add data type to scan result --- discoverx/scanner.py | 60 ++++++++++---------------------------- tests/unit/scanner_test.py | 46 ++++++++++++++--------------- 2 files changed, 39 insertions(+), 67 deletions(-) diff --git a/discoverx/scanner.py b/discoverx/scanner.py index f5e52ce..5272043 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -261,52 +261,20 @@ def _rule_matching_sql(self, table_info: TableInfo): if data_type: self.recursive_flatten_complex_type(col.name, data_type, column_list) - columns_pdf = pd.DataFrame(column_list) - # # prepare for cross-join - # columns_pdf["key"] = 0 - # # cross-join - # col_expr_pdf = columns_pdf.merge(expr_pdf, on=["key"]) - # - # def sum_expressions(row): - # if row.type == "string": - # return f"int(regexp_like({row.col_name}, '{row.rule_definition}'))" - # elif row.type == "array": - # return f"size(filter({row.col_name}, x -> x rlike '{row.rule_definition}'))" - # elif row.type == "map_values": - # return f"size(filter(map_values({row.col_name}), x -> x rlike '{row.rule_definition}'))" - # elif row.type == "map_keys": - # return f"size(filter(map_keys({row.col_name}), x -> x rlike '{row.rule_definition}'))" - # else: - # return None - # - # def count_expressions(row): - # if row.type == "string": - # return "1" - # elif row.type == "array": - # return f"size({row.col_name})" - # elif row.type == "map_values": - # return f"size(map_values({row.col_name}))" - # elif row.type == "map_keys": - # return f"size(map_keys({row.col_name}))" - # else: - # return None - # - # col_expr_pdf["sum_expression"] = col_expr_pdf.apply(sum_expressions, axis=1) - # col_expr_pdf["count_expression"] = col_expr_pdf.apply(count_expressions, axis=1) - # cols = [c for c in table_info.columns if c.data_type.lower() == "string"] - if len(columns_pdf) == 0: + + if len(column_list) == 0: raise Exception(f"There are no columns with supported types to be scanned in {table_info.table}") if not expressions: raise Exception(f"There are no rules to scan for.") - string_cols = columns_pdf.loc[columns_pdf.type == "string", "col_name"].to_list() + string_cols = [col for col in column_list if col["type"] == "string"] sql_list = [] if len(string_cols) > 0: sql_list.append(self.string_col_sql(string_cols, expressions, table_info)) - array_cols = columns_pdf.loc[columns_pdf.type == "array", "col_name"].to_list() + array_cols = [col for col in column_list if col["type"] == "array"] if len(array_cols) > 0: sql_list.append(self.array_col_sql(array_cols, expressions, table_info)) @@ -321,7 +289,7 @@ def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) - matching_string = ",\n ".join(matching_columns) unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}`" for r in expressions]) - unpivot_columns = ", ".join([f"'{c}', {c}" for c in cols]) + unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols]) sql = f""" SELECT @@ -329,25 +297,27 @@ def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) - '{table_info.database}' as database, '{table_info.table}' as table, column, + type, rule_name, (sum(value) / count(value)) as frequency FROM ( - SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) + SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) FROM ( SELECT column, + type, {matching_string} FROM ( SELECT - stack({len(cols)}, {unpivot_columns}) AS (column, value) + stack({len(cols)}, {unpivot_columns}) AS (column, type, value) FROM {catalog_str}{table_info.database}.{table_info.table} TABLESAMPLE ({self.sample_size} ROWS) ) ) ) - GROUP BY catalog, database, table, column, rule_name + GROUP BY catalog, database, table, column, type, rule_name """ return strip_margin(sql) @@ -363,7 +333,7 @@ def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> matching_string = ",\n ".join(matching_columns) unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}_sum`, `{r.name}_count`" for r in expressions]) - unpivot_columns = ", ".join([f"'{c}', {c}" for c in cols]) + unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols]) sql = f""" SELECT @@ -371,24 +341,26 @@ def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> '{table_info.database}' as database, '{table_info.table}' as table, column, + type, rule_name, (sum(value_sum) / sum(value_count)) as frequency FROM ( - SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count) + SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count) FROM ( SELECT column, + type, {matching_string} FROM ( SELECT - stack({len(cols)}, {unpivot_columns}) AS (column, value) + stack({len(cols)}, {unpivot_columns}) AS (column, type, value) FROM {catalog_str}{table_info.database}.{table_info.table} TABLESAMPLE ({self.sample_size} ROWS) ) ) ) - GROUP BY catalog, database, table, column, rule_name + GROUP BY catalog, database, table, column, type, rule_name """ return strip_margin(sql) diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index ac4cd77..8f3fb86 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -194,30 +194,30 @@ def test_scan_custom_rules(spark: SparkSession): def test_scan(spark: SparkSession): expected = pd.DataFrame( [ - ["None", "default", "tb_1", "ip", "ip_v4", 1.0], - ["None", "default", "tb_1", "ip", "ip_v6", 0.0], - ["None", "default", "tb_1", "mac", "ip_v4", 0.0], - ["None", "default", "tb_1", "mac", "ip_v6", 0.0], - ["None", "default", "tb_1", "description", "ip_v4", 0.0], - ["None", "default", "tb_1", "description", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`name`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`name`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`email`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "ip_v4", 1.0], - ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "ip_v6", 0.0], - ["None", "default", "tb_2", "`customer`.`products_owned`", "ip_v4", 0.0], - ["None", "default", "tb_2", "`customer`.`products_owned`", "ip_v6", 0.0], + ["None", "default", "tb_1", "ip", "string", "ip_v4", 1.0], + ["None", "default", "tb_1", "ip", "string", "ip_v6", 0.0], + ["None", "default", "tb_1", "mac", "string", "ip_v4", 0.0], + ["None", "default", "tb_1", "mac", "string", "ip_v6", 0.0], + ["None", "default", "tb_1", "description", "string", "ip_v4", 0.0], + ["None", "default", "tb_1", "description", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "array", "ip_v4", 1.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "array", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "array", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "array", "ip_v6", 0.0], ], - columns=["catalog", "database", "table", "column", "rule_name", "frequency"], + columns=["catalog", "database", "table", "column", "type", "rule_name", "frequency"], ) rules = Rules() From 524af432095a42471674641aa5cd9e9e978d9f56 Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Wed, 31 May 2023 17:10:49 +0200 Subject: [PATCH 6/7] add data type to classification result --- discoverx/classification.py | 13 +++++++------ tests/unit/classification_test.py | 25 ++++++++++++++++++++++++- tests/unit/conftest.py | 2 +- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/discoverx/classification.py b/discoverx/classification.py index 5c32632..f80628b 100644 --- a/discoverx/classification.py +++ b/discoverx/classification.py @@ -43,6 +43,7 @@ def above_threshold(self): "database": "table_schema", "table": "table_name", "column": "column_name", + "type": "data_type", "rule_name": "tag_name", } ) @@ -77,7 +78,7 @@ def aggregate_updates(pdf): return pd.DataFrame(output) - self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_4"]) + self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name", "data_type"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_5"]) # when testing we don't have a 3-level namespace but we need # to make sure we get None instead of NaN self.classification_result.table_catalog = self.classification_result.table_catalog.astype(object) @@ -94,7 +95,7 @@ def _get_classification_table_from_delta(self): self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {catalog + '.' + schema}") self.spark.sql( f""" - CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) + CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) """ ) logger.friendly(f"The classification table {self.classification_table_name} has been created.") @@ -154,8 +155,7 @@ def _stage_updates(self, input_classification_pdf: pd.DataFrame): classification_pdf["to_be_set"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) - set(x["Current Tags"])), axis=1) classification_pdf["to_be_kept"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) & set(x["Current Tags"])), axis=1) - self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True) - + self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name", "data_type"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True) def inspect(self): self.inspection_tool = InspectionTool(self.classification_result, self.publish) @@ -169,14 +169,14 @@ def publish(self, publish_uc_tags: bool): staged_updates_df = self.spark.createDataFrame( self.staged_updates, - "table_catalog: string, table_schema: string, table_name: string, column_name: string, action: string, tag_name: string", + "table_catalog: string, table_schema: string, table_name: string, column_name: string, data_type: string, action: string, tag_name: string", ).withColumn("effective_timestamp", func.current_timestamp()) # merge using scd-typ2 logger.friendly(f"Update classification table {self.classification_table_name}") self.classification_table.alias("target").merge( staged_updates_df.alias("source"), - "target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.tag_name = source.tag_name AND target.current = true", + "target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.data_type = source.data_type AND target.tag_name = source.tag_name AND target.current = true", ).whenMatchedUpdate( condition = "source.action = 'to_be_unset'", set={"current": "false", "end_timestamp": "source.effective_timestamp"} @@ -186,6 +186,7 @@ def publish(self, publish_uc_tags: bool): "table_schema": "source.table_schema", "table_name": "source.table_name", "column_name": "source.column_name", + "data_type": "source.data_type", "tag_name": "source.tag_name", "effective_timestamp": "source.effective_timestamp", "current": "true", diff --git a/tests/unit/classification_test.py b/tests/unit/classification_test.py index d4aee39..25dfd25 100644 --- a/tests/unit/classification_test.py +++ b/tests/unit/classification_test.py @@ -1,7 +1,6 @@ import pandas as pd from pandas.testing import assert_frame_equal import pytest -import numpy as np from discoverx.dx import DX from discoverx.dx import Scanner @@ -32,6 +31,7 @@ def test_classifier(spark): ], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip", "mac", "mac", "mac", "description", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "mac", "ip_v4", "ip_v6", "mac", "ip_v4", "ip_v6", "mac"], "frequency": [1.0, 0.0, 0.0, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0], } @@ -48,6 +48,7 @@ def test_classifier(spark): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "mac"], + "data_type": ["string", "string"], "Current Tags": [[], []], "Detected Tags": [["ip_v4"], ["mac"]], "Tags to be published": [["ip_v4"], ["mac"]], @@ -74,6 +75,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], } @@ -90,6 +92,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default"], "table_name": ["tb_1"], "column_name": ["ip"], + "data_type": ["string"], "tag_name": ["ip_v4"], "effective_timestamp": [pd.Timestamp(2023, 1, 1, 0)], "current": [True], @@ -112,6 +115,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default"], "table_name": ["tb_1"], "column_name": ["ip"], + "data_type": ["string"], "tag_name": ["ip_v4"], "effective_timestamp": [pd.Timestamp(2023, 1, 1, 0)], "current": [True], @@ -129,6 +133,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", "default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip6", "ip6", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [1.0, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0, 0.0], } @@ -147,6 +152,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "ip6"], + "data_type": ["string", "string"], "tag_name": ["ip_v4", "ip_v6"], "effective_timestamp": [current_time, current_time], "current": [True, True], @@ -169,6 +175,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", "default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip6", "ip6", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [0.7, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0, 0.0], } @@ -186,6 +193,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "ip6"], + "data_type": ["string", "string"], "tag_name": ["ip_v4", "ip_v6"], "effective_timestamp": [current_time, current_time], "current": [True, True], @@ -235,6 +243,20 @@ def test_merging_scan_results(spark, mock_current_time): "description", "description", ], + "type": [ + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + ], "rule_name": [ "ip_v4", "ip_v6", @@ -272,6 +294,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default", "default", "default"], "table_name": ["tb_1", "tb_1", "tb_1", "tb_2"], "column_name": ["ip", "ip6", "ip6", "mac"], + "data_type": ["string", "string", "string", "string"], "tag_name": ["ip_v4", "ip_v6", "pii", "mac"], "effective_timestamp": [current_time, current_time, current_time, current_time], "current": [False, True, True, True], diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 559bcf9..9d5a659 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -253,7 +253,7 @@ def get_classification_table_mock(self): self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}") self.spark.sql( f""" - CREATE TABLE IF NOT EXISTS {schema + '.' + table} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) USING DELTA + CREATE TABLE IF NOT EXISTS {schema + '.' + table} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) USING DELTA """ ) return DeltaTable.forName(self.spark, self.classification_table_name) From 448503848b6074b9a9c02d62d51ad12c18b607be Mon Sep 17 00:00:00 2001 From: "david.tempelmann" Date: Thu, 1 Jun 2023 11:23:36 +0200 Subject: [PATCH 7/7] add complex type support for search --- discoverx/dx.py | 1 + discoverx/msql.py | 17 ++++++-- discoverx/scanner.py | 3 +- tests/unit/conftest.py | 1 + tests/unit/data/tb_2.json | 2 +- tests/unit/dx_test.py | 83 ++++++++++++++++++++++++--------------- 6 files changed, 70 insertions(+), 37 deletions(-) diff --git a/discoverx/dx.py b/discoverx/dx.py index d23eee6..d9e5ef6 100644 --- a/discoverx/dx.py +++ b/discoverx/dx.py @@ -323,6 +323,7 @@ def _msql(self, msql: str, what_if: bool = False): func.col("table_schema").alias("database"), func.col("table_name").alias("table"), func.col("column_name").alias("column"), + "data_type", "tag_name", ).toPandas() ) diff --git a/discoverx/msql.py b/discoverx/msql.py index 9342c98..193b680 100644 --- a/discoverx/msql.py +++ b/discoverx/msql.py @@ -66,6 +66,15 @@ def compile_msql(self, table_info: TableInfo) -> list[SQLRow]: temp_sql = msql for tagged_col in tagged_cols: temp_sql = temp_sql.replace(f"[{tagged_col.tag}]", tagged_col.name) + # TODO: Can we avoid "replacing strings" for the different types in the future? This is due to the generation of MSQL. Maybe we should rather generate SQL directly from the search method... + if tagged_col.data_type == "array": + # return a string of the array as value to be able to union later + temp_sql = re.sub("(.*\'value\', )([^)]+)(\).*)", f"\g<1> array_join({tagged_col.name}, ', ') \g<3>", temp_sql) + # modify the WHERE condition to work with arrays + split_cond_sql = temp_sql.split("WHERE") + if len(split_cond_sql) > 1: + temp_sql = split_cond_sql[0] + "WHERE " + f"array_contains({tagged_col.name},{split_cond_sql[1].split('=')[1]})" + sql_statements.append(SQLRow(table_info.catalog, table_info.database, table_info.table, temp_sql)) return sql_statements @@ -77,9 +86,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]: classified_cols = classified_result_pdf.copy() classified_cols = classified_cols[classified_cols['tag_name'].isin(self.tags)] - classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column']).aggregate(lambda x: list(x))[['tag_name']].reset_index() + classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column', 'data_type']).aggregate(lambda x: list(x))[['tag_name']].reset_index() - classified_cols['col_tags'] = classified_cols[['column', 'tag_name']].apply(tuple, axis=1) + classified_cols['col_tags'] = classified_cols[['column', 'data_type', 'tag_name']].apply(tuple, axis=1) df = classified_cols.groupby(['catalog', 'database', 'table']).aggregate(lambda x: list(x))[['col_tags']].reset_index() # Filter tables by matching filter @@ -91,9 +100,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]: [ ColumnInfo( col[0], # col name - "", # TODO + col[1], # data type None, # TODO - col[1] # Tags + col[2] # Tags ) for col in row[3] ] ) for _, row in df.iterrows() if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.databases) and fnmatch(row[2], self.tables)] diff --git a/discoverx/scanner.py b/discoverx/scanner.py index 5272043..8c60796 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -29,12 +29,13 @@ class TableInfo: columns: list[ColumnInfo] def get_columns_by_tag(self, tag: str): - return [TaggedColumn(col.name, tag) for col in self.columns if tag in col.tags] + return [TaggedColumn(col.name, col.data_type, tag) for col in self.columns if tag in col.tags] @dataclass class TaggedColumn: name: str + data_type: str tag: str diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9d5a659..27c678f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -190,6 +190,7 @@ def sample_datasets(spark: SparkSession, request): logging.info("Test session finished, removing sample datasets") spark.sql("DROP TABLE IF EXISTS default.tb_1") + spark.sql("DROP TABLE IF EXISTS default.tb_2") spark.sql("DROP TABLE IF EXISTS default.columns_mock") if Path(warehouse_dir).exists(): shutil.rmtree(warehouse_dir) diff --git a/tests/unit/data/tb_2.json b/tests/unit/data/tb_2.json index 7d7222e..384b73f 100644 --- a/tests/unit/data/tb_2.json +++ b/tests/unit/data/tb_2.json @@ -1,3 +1,3 @@ {"customer": {"name": "AAA BBBB", "id": 1, "contact": {"address": {"street": "AAA street 11", "town": "AAA town", "postal_number": "111333", "country": "AAA country", "ips_used": []}, "email": "aaa.bbb@aaa.com"}, "products_owned": ["product1", "product2", "product10"], "interactions": {"service": "test aaa", "shop": "test shop aaa"}}, "active": true, "categories": {"cat1": "D"}} {"customer": {"name": "BBB CCCC", "id": 2, "contact": {"address": {"street": "BBB street 12", "town": "BBB town", "postal_number": "111233", "country": "BBB country", "ips_used": ["102.2.1.1", "103.3.1.1"]}, "email": "bbb.ccc@bbb.com"}, "products_owned": ["product1", "product10"], "interactions": {"service": "test bbb", "request": "test r bbb"}}, "active": false, "categories": {"cat1": "A", "cat2": "B", "cat3": "C"}} -{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country", "ips_used": ["102.1.1.1", "103.1.1.1", "104.1.1.1"]}, "email": "ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file +{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country", "ips_used": ["102.1.1.1", "1.2.3.4", "104.1.1.1"]}, "email": "ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file diff --git a/tests/unit/dx_test.py b/tests/unit/dx_test.py index c3b17f1..ef47221 100644 --- a/tests/unit/dx_test.py +++ b/tests/unit/dx_test.py @@ -7,9 +7,9 @@ @pytest.fixture(scope="module", name="dx_ip") -def scan_ip_in_tb1(spark, mock_uc_functionality): +def scan_ip_in_tb(spark, mock_uc_functionality): dx = DX(spark=spark, classification_table_name="_discoverx.tags") - dx.scan(from_tables="*.*.tb_1", rules="ip_*") + dx.scan(from_tables="*.*.tb_*", rules="ip_*") dx.publish() yield dx @@ -50,87 +50,108 @@ def test_scan_and_msql(spark, dx_ip): except Exception as e: pytest.fail(f"Test failed with exception {e}") + def test_search(spark, dx_ip: DX): # search a specific term and auto-detect matching tags/rules result = dx_ip.search("1.2.3.4").collect() - assert result[0].table == 'tb_1' - assert result[0].search_result.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].search_result.ip_v4.column == "ip" + assert result[1].table == "tb_2" + assert result[1].search_result.ip_v4.column == "`customer`.`contact`.`address`.`ips_used`" # search all records for specific tag - result_tags_only = dx_ip.search(by_tags='ip_v4') - assert {row.search_result.ip_v4.value for row in result_tags_only.collect()} == {"1.2.3.4", "3.4.5.60"} + result_tags_only = dx_ip.search(by_tags="ip_v4") + assert {row.search_result.ip_v4.value for row in result_tags_only.collect()} == { + "", + "1.2.3.4", + "102.1.1.1, 1.2.3.4, 104.1.1.1", + "102.2.1.1, 103.3.1.1", + "3.4.5.60", + } # specify catalog, database and table - result_tags_namespace = dx_ip.search(by_tags='ip_v4', from_tables="*.default.tb_*") + result_tags_namespace = dx_ip.search(by_tags="ip_v4", from_tables="*.default.tb_1") assert {row.search_result.ip_v4.value for row in result_tags_namespace.collect()} == {"1.2.3.4", "3.4.5.60"} # search specific term for list of specified tags - result_term_tag = dx_ip.search(search_term="3.4.5.60", by_tags=['ip_v4']).collect() - assert result_term_tag[0].table == 'tb_1' + result_term_tag = dx_ip.search(search_term="3.4.5.60", by_tags=["ip_v4"]).collect() + assert result_term_tag[0].table == "tb_1" assert result_term_tag[0].search_result.ip_v4.value == "3.4.5.60" with pytest.raises(ValueError) as no_tags_no_terms_error: dx_ip.search() - assert no_tags_no_terms_error.value.args[0] == "Neither search_term nor by_tags have been provided. At least one of them need to be specified." + assert ( + no_tags_no_terms_error.value.args[0] + == "Neither search_term nor by_tags have been provided. At least one of them need to be specified." + ) with pytest.raises(ValueError) as list_with_ints: - dx_ip.search(by_tags=[1, 3, 'ip']) - assert list_with_ints.value.args[0] == "The provided by_tags [1, 3, 'ip'] have the wrong type. Please provide either a str or List[str]." + dx_ip.search(by_tags=[1, 3, "ip"]) + assert ( + list_with_ints.value.args[0] + == "The provided by_tags [1, 3, 'ip'] have the wrong type. Please provide either a str or List[str]." + ) with pytest.raises(ValueError) as single_bool: dx_ip.search(by_tags=True) - assert single_bool.value.args[0] == "The provided by_tags True have the wrong type. Please provide either a str or List[str]." + assert ( + single_bool.value.args[0] + == "The provided by_tags True have the wrong type. Please provide either a str or List[str]." + ) def test_select_by_tag(spark, dx_ip): # search a specific term and auto-detect matching tags/rules result = dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags="ip_v4").collect() - assert result[0].table == 'tb_1' - assert result[0].tagged_columns.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].tagged_columns.ip_v4.column == "ip" result = dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=["ip_v4"]).collect() - assert result[0].table == 'tb_1' - assert result[0].tagged_columns.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].tagged_columns.ip_v4.column == "ip" with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="*.default.tb_*") - + with pytest.raises(ValueError): - dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=[1, 3, 'ip']) - + dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=[1, 3, "ip"]) + with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=True) with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="invalid from", by_tags="email") - + + # @pytest.mark.skip(reason="Delete is only working with v2 tables. Needs investigation") def test_delete_by_tag(spark, dx_ip): # search a specific term and auto-detect matching tags/rules result = dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9") - assert result is None # Nothing should be executed + assert result is None # Nothing should be executed - result = dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9", yes_i_am_sure=True).collect() - assert result[0].table == 'tb_1' + result = dx_ip.delete_by_tag( + from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9", yes_i_am_sure=True + ).collect() + assert result[0].table == "tb_1" with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="x") with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", values="x") - + with pytest.raises(ValueError): - dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=['ip'], values="x") - + dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=["ip"], values="x") + with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=True, values="x") with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="invalid from", by_tag="email", values="x") - + # test multiple tags def test_search_multiple(spark, mock_uc_functionality): @@ -140,8 +161,8 @@ def test_search_multiple(spark, mock_uc_functionality): # search a specific term and auto-detect matching tags/rules result = dx.search(by_tags=["ip_v4", "mac"]) - assert result.collect()[0].table == 'tb_1' - assert result.collect()[0].search_result.ip_v4.column == 'ip' - assert result.collect()[0].search_result.mac.column == 'mac' + assert result.collect()[0].table == "tb_1" + assert result.collect()[0].search_result.ip_v4.column == "ip" + assert result.collect()[0].search_result.mac.column == "mac" spark.sql("DROP TABLE IF EXISTS _discoverx.tags")