diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index f4b8e8c..8867091 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -35,7 +35,6 @@ class DataFrameProhibitedColumnError(ValueError): def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]: """Validate the presence of column names in a DataFrame. - :param df: A spark DataFrame. :type df: DataFrame :param required_col_names: List of the required column names for the DataFrame. @@ -48,13 +47,13 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], r """ all_col_names = df.columns missing_col_names = [x for x in required_col_names if x not in all_col_names] - + if missing_col_names: error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}" if return_bool: return False raise DataFrameMissingColumnError(error_message) - + return True if return_bool else None @@ -65,7 +64,6 @@ def validate_schema( return_bool: bool = False, ) -> Union[None, bool]: """Function that validate if a given DataFrame has a given StructType as its schema. - :param df: DataFrame to validate :type df: DataFrame :param required_schema: StructType required for the DataFrame @@ -90,19 +88,20 @@ def validate_schema( x.nullable = None missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields] - + if missing_struct_fields: - error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" + error_message = ( + f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" + ) if return_bool: return False raise DataFrameMissingStructFieldError(error_message) - + return True if return_bool else None def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]: """Validate that none of the prohibited column names are present among specified DataFrame columns. - :param df: DataFrame containing columns to be checked. :param prohibited_col_names: List of prohibited column names. :param return_bool: If True, return a boolean instead of raising an exception. @@ -113,11 +112,11 @@ def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], """ all_col_names = df.columns extra_col_names = [x for x in all_col_names if x in prohibited_col_names] - + if extra_col_names: error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}" if return_bool: return False raise DataFrameProhibitedColumnError(error_message) - - return True if return_bool else None \ No newline at end of file + + return True if return_bool else None diff --git a/tests/test_dataframe_validator.py b/tests/test_dataframe_validator.py index b99e3c7..4debd01 100644 --- a/tests/test_dataframe_validator.py +++ b/tests/test_dataframe_validator.py @@ -1,7 +1,6 @@ import pytest from pyspark.sql.types import StructType, StructField, StringType, LongType import semver - import quinn from .spark import spark @@ -21,7 +20,7 @@ def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) quinn.validate_presence_of_columns(source_df, ["name"], False) - + def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) @@ -66,7 +65,7 @@ def it_does_nothing_when_the_schema_matches_and_return_bool_is_false(): ] ) quinn.validate_schema(source_df, required_schema, return_bool = False) - + def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) @@ -118,7 +117,7 @@ def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_fal data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) quinn.validate_absence_of_columns(source_df, ["favorite_color"], False) - + def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"])