diff --git a/scripts/types_generator/main.py b/scripts/types_generator/main.py index 129bc866..ddc7eac8 100644 --- a/scripts/types_generator/main.py +++ b/scripts/types_generator/main.py @@ -50,6 +50,7 @@ def generate_models(schema_path: Path, output: Path, extra_template_data: Option target_python_version=PythonVersion.PY_39, base_class=ExtractorConfig.base_model_class, additional_imports=[ + "warnings", "deprecated", "pydantic.field_validator", "pydantic.computed_field", diff --git a/scripts/types_generator/schema_aliases.yaml b/scripts/types_generator/schema_aliases.yaml index 39d67db5..0897365c 100644 --- a/scripts/types_generator/schema_aliases.yaml +++ b/scripts/types_generator/schema_aliases.yaml @@ -822,3 +822,23 @@ model_extensions: PromptModerationParameters: custom_base_class: ModerationParameters + + TuneResultDatapointValidationLossData: + custom_body: | + @field_validator("epoch", mode="before") + @classmethod + def _validate_epoch(cls, value: Any): + result_value = int(value) + if result_value != float(value): + warnings.warn(f'The epoch was rounded down from {value} to {result_value}', stacklevel=4) + return result_value + + TuneResultDatapointLossData: + custom_body: | + @field_validator("epoch", mode="before") + @classmethod + def _validate_epoch(cls, value: Any): + result_value = int(value) + if result_value != float(value): + warnings.warn(f'The epoch was rounded down from {value} to {result_value}', stacklevel=4) + return result_value diff --git a/src/genai/schema/_api.py b/src/genai/schema/_api.py index 986230b1..09a7275b 100644 --- a/src/genai/schema/_api.py +++ b/src/genai/schema/_api.py @@ -3,11 +3,12 @@ from __future__ import annotations +import warnings from datetime import date from enum import Enum from typing import Any, Literal, Optional, Union -from pydantic import AwareDatetime, Field, RootModel +from pydantic import AwareDatetime, Field, RootModel, field_validator from genai._types import ApiBaseModel @@ -2208,6 +2209,14 @@ class TuneResultDatapointLossData(ApiBaseModel): step: Optional[int] = None value: float + @field_validator("epoch", mode="before") + @classmethod + def _validate_epoch(cls, value: Any): + result_value = int(value) + if result_value != float(value): + warnings.warn(f"The epoch was rounded down from {value} to {result_value}", stacklevel=4) + return result_value + class TuneResultDatapointValidationLoss(ApiBaseModel): data: TuneResultDatapointValidationLossData @@ -2219,6 +2228,14 @@ class TuneResultDatapointValidationLossData(ApiBaseModel): step: Optional[int] = None value: float + @field_validator("epoch", mode="before") + @classmethod + def _validate_epoch(cls, value: Any): + result_value = int(value) + if result_value != float(value): + warnings.warn(f"The epoch was rounded down from {value} to {result_value}", stacklevel=4) + return result_value + class TuneResultFiles(ApiBaseModel): created_at: Optional[AwareDatetime] = None