chore(ingest): start working on pydantic v2 support (#9220)
hsheth2 committed Nov 10, 2023
1 parent b851d59 commit 89dff8f
Showing 44 changed files with 216 additions and 139 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/scripts/avro_codegen.py
@@ -192,7 +192,7 @@ def add_avro_python3_warning(filepath: Path) -> None:
# This means that installation order matters, which is a pretty unintuitive outcome.
# See https://github.com/pypa/pip/issues/4625 for details.
try:
from avro.schema import SchemaFromJSONData
from avro.schema import SchemaFromJSONData # type: ignore
import warnings
warnings.warn("It seems like 'avro-python3' is installed, which conflicts with the 'avro' package used by datahub. "
1 change: 1 addition & 0 deletions metadata-ingestion/setup.cfg
@@ -88,6 +88,7 @@ filterwarnings =
ignore:Deprecated call to \`pkg_resources.declare_namespace:DeprecationWarning
ignore:pkg_resources is deprecated as an API:DeprecationWarning
ignore:Did not recognize type:sqlalchemy.exc.SAWarning
ignore::datahub.configuration.pydantic_migration_helpers.PydanticDeprecatedSince20

[coverage:run]
# Because of some quirks in the way setup.cfg, coverage.py, pytest-cov,
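The added `ignore::datahub.configuration.pydantic_migration_helpers.PydanticDeprecatedSince20` filter references the warning class introduced later in this commit. A rough sketch of its programmatic equivalent (illustrative only, assuming the helper module is importable in the test environment):

```python
# Illustrative only: the programmatic equivalent of the new setup.cfg filter entry.
import warnings

from datahub.configuration.pydantic_migration_helpers import PydanticDeprecatedSince20

# Silence the pydantic v2 deprecation warnings re-exported (or stubbed) by the helper module.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
```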
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/cli/cli_utils.py
@@ -47,7 +47,7 @@

class GmsConfig(BaseModel):
server: str
token: Optional[str]
token: Optional[str] = None


class DatahubConfig(BaseModel):
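The `token: Optional[str] = None` change above is the first of many in this commit: pydantic v2 no longer treats an `Optional[...]` annotation as implicitly defaulting to `None`, so such fields become required unless an explicit default is given. A minimal sketch of the difference (hypothetical model names):

```python
from typing import Optional

import pydantic


class V1Style(pydantic.BaseModel):
    token: Optional[str]  # optional under pydantic v1, but *required* under v2


class PortableStyle(pydantic.BaseModel):
    token: Optional[str] = None  # optional under both major versions


PortableStyle()  # fine everywhere
# V1Style() succeeds under pydantic v1 but raises ValidationError under pydantic v2.
```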
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/cli/lite_cli.py
@@ -40,7 +40,7 @@ class DuckDBLiteConfigWrapper(DuckDBLiteConfig):

class LiteCliConfig(DatahubConfig):
lite: LiteLocalConfig = LiteLocalConfig(
type="duckdb", config=DuckDBLiteConfigWrapper()
type="duckdb", config=DuckDBLiteConfigWrapper().dict()
)


26 changes: 21 additions & 5 deletions metadata-ingestion/src/datahub/configuration/_config_enum.py
@@ -4,6 +4,8 @@
import pydantic.types
import pydantic.validators

from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2


class ConfigEnum(Enum):
# Ideally we would use @staticmethod here, but some versions of Python don't support it.
@@ -15,11 +17,25 @@ def _generate_next_value_( # type: ignore
# From https://stackoverflow.com/a/44785241/5004662.
return name

@classmethod
def __get_validators__(cls) -> "pydantic.types.CallableGenerator":
# We convert the text to uppercase before attempting to match it to an enum value.
yield cls.validate
yield pydantic.validators.enum_member_validator
if PYDANTIC_VERSION_2:
# if TYPE_CHECKING:
# from pydantic import GetCoreSchemaHandler

@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler): # type: ignore
from pydantic_core import core_schema

return core_schema.no_info_before_validator_function(
cls.validate, handler(source_type)
)

else:

@classmethod
def __get_validators__(cls) -> "pydantic.types.CallableGenerator":
# We convert the text to uppercase before attempting to match it to an enum value.
yield cls.validate
yield pydantic.validators.enum_member_validator

@classmethod
def validate(cls, v): # type: ignore[no-untyped-def]
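With both hooks registered, string inputs keep flowing through `ConfigEnum.validate` (which upper-cases the text before matching it to an enum value) regardless of the installed pydantic major. A hedged usage sketch with hypothetical enum and model names:

```python
from datahub.configuration._config_enum import ConfigEnum
from datahub.configuration.common import ConfigModel


class MatchMode(ConfigEnum):  # hypothetical enum for illustration
    SENSITIVE = "SENSITIVE"
    INSENSITIVE = "INSENSITIVE"


class DemoConfig(ConfigModel):  # hypothetical model for illustration
    mode: MatchMode = MatchMode.SENSITIVE


# Case-insensitive parsing works on pydantic v1 (__get_validators__) and
# v2 (__get_pydantic_core_schema__) alike.
assert DemoConfig(mode="insensitive").mode is MatchMode.INSENSITIVE
```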
27 changes: 19 additions & 8 deletions metadata-ingestion/src/datahub/configuration/common.py
@@ -11,6 +11,7 @@
from typing_extensions import Protocol, runtime_checkable

from datahub.configuration._config_enum import ConfigEnum
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
from datahub.utilities.dedup_list import deduplicate_list

_ConfigSelf = TypeVar("_ConfigSelf", bound="ConfigModel")
@@ -71,14 +72,8 @@ def redact_raw_config(obj: Any) -> Any:

class ConfigModel(BaseModel):
class Config:
extra = Extra.forbid
underscore_attrs_are_private = True
keep_untouched = (
cached_property,
) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info.

@staticmethod
def schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
# We use the custom "hidden_from_docs" attribute to hide fields from the
# autogenerated docs.
remove_fields = []
@@ -89,6 +84,19 @@ def schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None:
for key in remove_fields:
del schema["properties"][key]

# This is purely to suppress pydantic's warnings, since this class is used everywhere.
if PYDANTIC_VERSION_2:
extra = "forbid"
ignored_types = (cached_property,)
json_schema_extra = _schema_extra
else:
extra = Extra.forbid
underscore_attrs_are_private = True
keep_untouched = (
cached_property,
) # needed to allow cached_property to work. See https://github.com/samuelcolvin/pydantic/issues/1241 for more info.
schema_extra = _schema_extra

@classmethod
def parse_obj_allow_extras(cls: Type[_ConfigSelf], obj: Any) -> _ConfigSelf:
with unittest.mock.patch.object(cls.Config, "extra", pydantic.Extra.allow):
@@ -102,7 +110,10 @@ class PermissiveConfigModel(ConfigModel):
# It is usually used for argument bags that are passed through to third-party libraries.

class Config:
extra = Extra.allow
if PYDANTIC_VERSION_2:
extra = "allow"
else:
extra = Extra.allow


class TransformerSemantics(ConfigEnum):
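The shared `Config` above now carries the v2 spellings (`extra = "forbid"`, `ignored_types`, `json_schema_extra`) alongside the v1 ones, so every `ConfigModel` subclass keeps the same behaviour on either major: unknown keys are rejected, `cached_property` attributes are left untouched, and fields flagged with the custom `hidden_from_docs` attribute are stripped from the generated schema. A rough illustration (hypothetical model and field names):

```python
import pydantic

from datahub.configuration.common import ConfigModel


class ExampleConfig(ConfigModel):  # hypothetical model for illustration
    endpoint: str
    internal_flag: bool = pydantic.Field(default=False, hidden_from_docs=True)


try:
    ExampleConfig(endpoint="http://localhost:8080", unknown_key=1)
except pydantic.ValidationError:
    print("extra keys are rejected on both pydantic v1 and v2")

# The _schema_extra hook removes hidden_from_docs fields from the autogenerated docs schema.
```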
6 changes: 3 additions & 3 deletions metadata-ingestion/src/datahub/configuration/oauth.py
@@ -24,11 +24,11 @@ class OAuthConfiguration(ConfigModel):
default=False,
)
client_secret: Optional[SecretStr] = Field(
description="client secret of the application if use_certificate = false"
None, description="client secret of the application if use_certificate = false"
)
encoded_oauth_public_key: Optional[str] = Field(
description="base64 encoded certificate content if use_certificate = true"
None, description="base64 encoded certificate content if use_certificate = true"
)
encoded_oauth_private_key: Optional[str] = Field(
description="base64 encoded private key content if use_certificate = true"
None, description="base64 encoded private key content if use_certificate = true"
)
30 changes: 30 additions & 0 deletions metadata-ingestion/src/datahub/configuration/pydantic_migration_helpers.py
@@ -0,0 +1,30 @@
import pydantic.version
from packaging.version import Version

PYDANTIC_VERSION_2: bool
if Version(pydantic.version.VERSION) >= Version("2.0"):
PYDANTIC_VERSION_2 = True
else:
PYDANTIC_VERSION_2 = False


# This can be used to silence deprecation warnings while we migrate.
if PYDANTIC_VERSION_2:
from pydantic import PydanticDeprecatedSince20 # type: ignore
else:

class PydanticDeprecatedSince20(Warning): # type: ignore
pass


if PYDANTIC_VERSION_2:
from pydantic import BaseModel as GenericModel
else:
from pydantic.generics import GenericModel # type: ignore


__all__ = [
"PYDANTIC_VERSION_2",
"PydanticDeprecatedSince20",
"GenericModel",
]
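A brief sketch of how the new helper module is meant to be consumed by the rest of the codebase (the `Wrapper` model below is hypothetical):

```python
from typing import Generic, TypeVar

from datahub.configuration.pydantic_migration_helpers import (
    PYDANTIC_VERSION_2,
    GenericModel,
)

T = TypeVar("T")


# GenericModel resolves to pydantic.generics.GenericModel on v1 and plain BaseModel on v2,
# so generic models can be declared the same way on both majors.
class Wrapper(GenericModel, Generic[T]):
    value: T


if PYDANTIC_VERSION_2:
    ...  # v2-only branches, e.g. __get_pydantic_core_schema__ registration
else:
    ...  # v1-only branches, e.g. __get_validators__ registration
```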
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/emitter/mcp_builder.py
@@ -127,7 +127,7 @@ class BucketKey(ContainerKey):
class NotebookKey(DatahubKey):
notebook_id: int
platform: str
instance: Optional[str]
instance: Optional[str] = None

def as_urn(self) -> str:
return make_dataset_urn_with_platform_instance(
Original file line number Diff line number Diff line change
@@ -26,11 +26,11 @@ def _try_reformat_with_black(code: str) -> str:


class WorkUnitRecordExtractorConfig(ConfigModel):
set_system_metadata = True
set_system_metadata_pipeline_name = (
set_system_metadata: bool = True
set_system_metadata_pipeline_name: bool = (
False # false for now until the models are available in OSS
)
unpack_mces_into_mcps = False
unpack_mces_into_mcps: bool = False


class WorkUnitRecordExtractor(
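The added `: bool` annotations here (and on the kafka and dbt configs below) address another v2 incompatibility: pydantic v2 refuses bare, unannotated assignments as field definitions, whereas v1 silently inferred the type from the default. A minimal sketch:

```python
import pydantic


class AnnotatedConfig(pydantic.BaseModel):  # hypothetical model for illustration
    set_system_metadata: bool = True  # a proper field on both pydantic v1 and v2


# The pre-change spelling,
#     class UnannotatedConfig(pydantic.BaseModel):
#         set_system_metadata = True
# is accepted by pydantic v1 (field type inferred from the default) but raises a
# PydanticUserError ("A non-annotated attribute was detected") under pydantic v2.
```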
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ class PipelineConfig(ConfigModel):

source: SourceConfig
sink: DynamicTypedConfig
transformers: Optional[List[DynamicTypedConfig]]
transformers: Optional[List[DynamicTypedConfig]] = None
flags: FlagsConfig = Field(default=FlagsConfig(), hidden_from_docs=True)
reporting: List[ReporterConfig] = []
run_id: str = DEFAULT_RUN_ID
Original file line number Diff line number Diff line change
@@ -265,7 +265,7 @@ def validate_column_lineage(cls, v: bool, values: Dict[str, Any]) -> bool:
description="Option to exclude empty projects from being ingested.",
)

@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def profile_default_settings(cls, values: Dict) -> Dict:
# Extra default SQLAlchemy option for better connection pooling and threading.
# https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow
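This and the other `@root_validator(skip_on_failure=True)` changes throughout the commit follow one rule: pydantic v2's v1-compatibility layer rejects a bare `@root_validator` (and `pre=False` without `skip_on_failure=True`), while v1 also accepts the explicit flag, so this spelling runs on both majors. A hedged sketch with a hypothetical model:

```python
from typing import Any, Dict

import pydantic


class ConnectionConfig(pydantic.BaseModel):  # hypothetical model for illustration
    host: str = "localhost"
    port: int = 5432

    # Runs after field validation, and only if all fields validated successfully.
    @pydantic.root_validator(skip_on_failure=True)
    def check_port(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if not 0 < values["port"] < 65536:
            raise ValueError("port must be a valid TCP port")
        return values
```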
Original file line number Diff line number Diff line change
@@ -214,7 +214,7 @@ def glob_include(self):
logger.debug(f"Setting _glob_include: {glob_include}")
return glob_include

@pydantic.root_validator()
@pydantic.root_validator(skip_on_failure=True)
def validate_path_spec(cls, values: Dict) -> Dict[str, Any]:
# validate that main fields are populated
required_fields = ["include", "file_types", "default_extension"]
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
hidden_from_docs=True,
)

@root_validator
@root_validator(skip_on_failure=True)
def check_ingesting_data(cls, values):
if (
not values.get("database_connection")
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ class DBTCloudConfig(DBTCommonConfig):
description="The ID of the job to ingest metadata from.",
)
run_id: Optional[int] = Field(
None,
description="The ID of the run to ingest metadata from. If not specified, we'll default to the latest run.",
)

Original file line number Diff line number Diff line change
@@ -150,7 +150,7 @@ class DBTEntitiesEnabled(ConfigModel):
description="Emit metadata for test results when set to Yes or Only",
)

@root_validator
@root_validator(skip_on_failure=True)
def process_only_directive(cls, values):
# Checks that at most one is set to ONLY, and then sets the others to NO.

@@ -229,15 +229,15 @@ class DBTCommonConfig(
default={},
description="mapping rules that will be executed against dbt column meta properties. Refer to the section below on dbt meta automated mappings.",
)
enable_meta_mapping = Field(
enable_meta_mapping: bool = Field(
default=True,
description="When enabled, applies the mappings that are defined through the meta_mapping directives.",
)
query_tag_mapping: Dict = Field(
default={},
description="mapping rules that will be executed against dbt query_tag meta properties. Refer to the section below on dbt meta automated mappings.",
)
enable_query_tag_mapping = Field(
enable_query_tag_mapping: bool = Field(
default=True,
description="When enabled, applies the mappings that are defined through the `query_tag_mapping` directives.",
)
4 changes: 2 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/kafka.py
@@ -100,11 +100,11 @@ class KafkaSourceConfig(
default="datahub.ingestion.source.confluent_schema_registry.ConfluentSchemaRegistry",
description="The fully qualified implementation class(custom) that implements the KafkaSchemaRegistryBase interface.",
)
schema_tags_field = pydantic.Field(
schema_tags_field: str = pydantic.Field(
default="tags",
description="The field name in the schema metadata that contains the tags to be added to the dataset.",
)
enable_meta_mapping = pydantic.Field(
enable_meta_mapping: bool = pydantic.Field(
default=True,
description="When enabled, applies the mappings that are defined through the meta_mapping directives.",
)
Original file line number Diff line number Diff line change
@@ -275,7 +275,7 @@ def convert_string_to_connection_def(cls, conn_map):
)
return conn_map

@root_validator()
@root_validator(skip_on_failure=True)
def check_either_connection_map_or_connection_provided(cls, values):
"""Validate that we must either have a connection map or an api credential"""
if not values.get("connection_to_platform_map", {}) and not values.get(
Expand All @@ -286,7 +286,7 @@ def check_either_connection_map_or_connection_provided(cls, values):
)
return values

@root_validator()
@root_validator(skip_on_failure=True)
def check_either_project_name_or_api_provided(cls, values):
"""Validate that we must either have a project name or an api credential to fetch project names"""
if not values.get("project_name") and not values.get("api"):
@@ -1070,7 +1070,6 @@ def _get_fields(
def determine_view_file_path(
cls, base_folder_path: str, absolute_file_path: str
) -> str:

splits: List[str] = absolute_file_path.split(base_folder_path, 1)
if len(splits) != 2:
logger.debug(
@@ -1104,7 +1103,6 @@ def from_looker_dict(
populate_sql_logic_in_descriptions: bool = False,
process_isolation_for_sql_parsing: bool = False,
) -> Optional["LookerView"]:

view_name = looker_view["name"]
logger.debug(f"Handling view {view_name} in model {model_name}")
# The sql_table_name might be defined in another view and this view is extending that view,
@@ -2087,7 +2085,6 @@ def get_internal_workunits(self) -> Iterable[MetadataWorkUnit]: # noqa: C901
)

if looker_viewfile is not None:

for raw_view in looker_viewfile.views:
raw_view_name = raw_view["name"]
if LookerRefinementResolver.is_refinement(raw_view_name):
4 changes: 2 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/nifi.py
@@ -126,7 +126,7 @@ class NifiSourceConfig(EnvConfigMixin):
description="Path to PEM file containing certs for the root CA(s) for the NiFi",
)

@root_validator
@root_validator(skip_on_failure=True)
def validate_auth_params(cla, values):
if values.get("auth") is NifiAuthType.CLIENT_CERT and not values.get(
"client_cert_file"
@@ -143,7 +143,7 @@ def validate_auth_params(cla, values):
)
return values

@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def validator_site_url_to_site_name(cls, values):
site_url_to_site_name = values.get("site_url_to_site_name")
site_url = values.get("site_url")
Original file line number Diff line number Diff line change
@@ -405,8 +405,7 @@ class PowerBiDashboardSourceConfig(
"Works for M-Query where native SQL is used for transformation.",
)

@root_validator
@classmethod
@root_validator(skip_on_failure=True)
def validate_extract_column_level_lineage(cls, values: Dict) -> Dict:
flags = [
"native_query_parsing",
@@ -445,7 +444,7 @@ def map_data_platform(cls, value):

return value

@root_validator(pre=False)
@root_validator(skip_on_failure=True)
def workspace_id_backward_compatibility(cls, values: Dict) -> Dict:
workspace_id = values.get("workspace_id")
workspace_id_pattern = values.get("workspace_id_pattern")