Skip to content

Commit

Permalink
Merge pull request #267 from paulooctavio/feature/Update-validate-methods-to-return-boolean-fix-linting
Browse files (browse the repository at this point in the history)

Added new parameter return_bool to validate dataframe methods (fix linting)
  • Loading branch information
jeffbrennan authored Oct 4, 2024
2 parents a0849f3 + ab90811 commit 9156cee
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
21 changes: 10 additions & 11 deletions quinn/dataframe_validator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,7 +35,6 @@ class DataFrameProhibitedColumnError(ValueError):

def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate the presence of column names in a DataFrame.
:param df: A spark DataFrame.
:type df: DataFrame
:param required_col_names: List of the required column names for the DataFrame.
Expand All @@ -48,13 +47,13 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], r
"""
all_col_names = df.columns
missing_col_names = [x for x in required_col_names if x not in all_col_names]

if missing_col_names:
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameMissingColumnError(error_message)

return True if return_bool else None


Expand All @@ -65,7 +64,6 @@ def validate_schema(
return_bool: bool = False,
) -> Union[None, bool]:
"""Function that validates if a given DataFrame has a given StructType as its schema.
:param df: DataFrame to validate
:type df: DataFrame
:param required_schema: StructType required for the DataFrame
Expand All @@ -90,19 +88,20 @@ def validate_schema(
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]

if missing_struct_fields:
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
error_message = (
f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
)
if return_bool:
return False
raise DataFrameMissingStructFieldError(error_message)

return True if return_bool else None


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate that none of the prohibited column names are present among specified DataFrame columns.
:param df: DataFrame containing columns to be checked.
:param prohibited_col_names: List of prohibited column names.
:param return_bool: If True, return a boolean instead of raising an exception.
Expand All @@ -113,11 +112,11 @@ def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str],
"""
all_col_names = df.columns
extra_col_names = [x for x in all_col_names if x in prohibited_col_names]

if extra_col_names:
error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameProhibitedColumnError(error_message)
return True if return_bool else None

return True if return_bool else None
7 changes: 3 additions & 4 deletions tests/test_dataframe_validator.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
import pytest
from pyspark.sql.types import StructType, StructField, StringType, LongType
import semver

import quinn
from .spark import spark

Expand All @@ -21,7 +20,7 @@ def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_presence_of_columns(source_df, ["name"], False)

def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down Expand Up @@ -66,7 +65,7 @@ def it_does_nothing_when_the_schema_matches_and_return_bool_is_false():
]
)
quinn.validate_schema(source_df, required_schema, return_bool = False)

def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down Expand Up @@ -118,7 +117,7 @@ def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_fal
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_absence_of_columns(source_df, ["favorite_color"], False)

def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down

0 comments on commit 9156cee

Please sign in to comment.