From 04571ef81045e32f850e61c8947fa8957351ea83 Mon Sep 17 00:00:00 2001
From: Puneet Sharma <65413434+puneetsharma04@users.noreply.github.com>
Date: Sun, 9 Apr 2023 00:29:39 +0200
Subject: [PATCH] Append schema functionality for review & comments. (#86)

* Added files for schema append functionality

* Update test_append_if_schema_identical.py

* Made the changes as per the review comments

* Made the changes as per the review comments & added comments for better readability.

* Made the changes as per the review comments & added comments for better readability.
---
Review note: append_if_schema_identical.py and its test were revised after this
patch was exported (the error message now separates its sentences and reports
mismatches in both directions; the test asserts on the result). The blob ids on
the "index" lines of the two new files predate those revisions.

 quinn/__init__.py                        |  1 +
 quinn/append_if_schema_identical.py      | 37 ++++++++++++++++++++++++
 tests/test_append_if_schema_identical.py | 33 +++++++++++++++++++++
 3 files changed, 71 insertions(+)
 create mode 100644 quinn/append_if_schema_identical.py
 create mode 100644 tests/test_append_if_schema_identical.py

diff --git a/quinn/__init__.py b/quinn/__init__.py
index dbd5ee1c..2383f267 100644
--- a/quinn/__init__.py
+++ b/quinn/__init__.py
@@ -5,3 +5,4 @@
 from .functions import *
 from .scala_to_pyspark import ScalaToPyspark
 from .transformations import *
+from .append_if_schema_identical import append_if_schema_identical
diff --git a/quinn/append_if_schema_identical.py b/quinn/append_if_schema_identical.py
new file mode 100644
index 00000000..c6121cfb
--- /dev/null
+++ b/quinn/append_if_schema_identical.py
@@ -0,0 +1,37 @@
+from pyspark.sql import DataFrame
+
+
+class SchemaMismatchError(ValueError):
+    """Raised when the source and target DataFrame schemas do not match."""
+
+
+def append_if_schema_identical(source_df: DataFrame, target_df: DataFrame) -> DataFrame:
+    """Append ``source_df`` to ``target_df`` when both have the same column names and types.
+
+    Column order is ignored when comparing the schemas; the rows are combined with
+    ``unionByName``, called on ``target_df``.
+
+    :param source_df: DataFrame whose rows are appended
+    :type source_df: pyspark.sql.DataFrame
+    :param target_df: DataFrame the rows are appended to
+    :type target_df: pyspark.sql.DataFrame
+    :return: result of ``target_df.unionByName(source_df)``
+    :rtype: pyspark.sql.DataFrame
+    :raises SchemaMismatchError: if the (column name, data type) pairs of the schemas differ
+    """
+    # Describe each schema as a list of (column name, data type string) tuples
+    source_schema_list = [(field.name, str(field.dataType)) for field in source_df.schema]
+    target_schema_list = [(field.name, str(field.dataType)) for field in target_df.schema]
+
+    # Sorting makes the comparison independent of column order
+    if sorted(source_schema_list) != sorted(target_schema_list):
+        missing_in_target = [col for col in source_schema_list if col not in target_schema_list]
+        missing_in_source = [col for col in target_schema_list if col not in source_schema_list]
+        raise SchemaMismatchError(
+            "The schemas of the source and target dataframes are not identical. "
+            f"From source schema column {missing_in_target} is missing in target schema. "
+            f"From target schema column {missing_in_source} is missing in source schema."
+        )
+
+    # Append the dataframes now that the schemas are known to match
+    return target_df.unionByName(source_df)
diff --git a/tests/test_append_if_schema_identical.py b/tests/test_append_if_schema_identical.py
new file mode 100644
index 00000000..67034e0d
--- /dev/null
+++ b/tests/test_append_if_schema_identical.py
@@ -0,0 +1,33 @@
+from pyspark.sql.types import StructType, StructField, IntegerType, StringType
+import quinn
+from tests.conftest import auto_inject_fixtures
+
+
+@auto_inject_fixtures("spark")
+def test_append_if_schema_identical(spark):
+    source_data = [(1, "capetown", "Alice"), (2, "delhi", "Bob")]
+    target_data = [(3, "Charlie", "New York"), (4, "Dave", "Los Angeles")]
+
+    source_df = spark.createDataFrame(source_data, schema=StructType([
+        StructField("id", IntegerType()),
+        StructField("city", StringType()),
+        StructField("name", StringType())
+    ]))
+
+    target_df = spark.createDataFrame(target_data, schema=StructType([
+        StructField("id", IntegerType()),
+        StructField("name", StringType()),
+        StructField("city", StringType())
+    ]))
+
+    # Call the append_if_schema_identical function
+    appended_df = quinn.append_if_schema_identical(source_df, target_df)
+
+    # unionByName resolves columns by name, so the result keeps the target's column order
+    assert appended_df.columns == ["id", "name", "city"]
+    assert sorted(tuple(row) for row in appended_df.collect()) == [
+        (1, "Alice", "capetown"),
+        (2, "Bob", "delhi"),
+        (3, "Charlie", "New York"),
+        (4, "Dave", "Los Angeles"),
+    ]