Skip to content

Commit

Permalink
minor revision to the base function model
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Feb 2, 2024
1 parent 25d69d3 commit dcac7a7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 13 deletions.
68 changes: 56 additions & 12 deletions async_openai/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
import jinja2
import functools
from abc import ABC
from pydantic import PrivateAttr, BaseModel
from pydantic import Field, BaseModel
# from lazyops.types import BaseModel
from lazyops.utils.times import Timer
from lazyops.libs.proxyobj import ProxyObject
from lazyops.types.models import schema_extra, PYD_VERSION
from async_openai.utils.fixjson import resolve_json
from . import errors

from typing import Optional, Any, Dict, List, Union, Type, Tuple, Awaitable, Generator, AsyncGenerator, TypeVar, TYPE_CHECKING

if PYD_VERSION == 2:
from pydantic import ConfigDict

if TYPE_CHECKING:
from async_openai import ChatResponse, ChatRoute
from async_openai.types.resources import Usage
Expand All @@ -29,11 +33,14 @@


class BaseFunctionModel(BaseModel):
_name: Optional[str] = PrivateAttr(None)
function_name: Optional[str] = Field(None, hidden = True)
function_model: Optional[str] = Field(None, hidden = True)
function_duration: Optional[float] = Field(None, hidden = True)

if TYPE_CHECKING:
usage: Optional[Usage]
function_usage: Optional[Usage]
else:
usage: Optional[Any] = None
function_usage: Optional[Any] = Field(None, hidden = True)

def update(
self,
Expand Down Expand Up @@ -84,6 +91,47 @@ def is_valid(self) -> bool:
Returns whether the function data is valid
"""
return True

def _set_values_from_response(
self,
response: 'ChatResponse',
name: Optional[str] = None,
**kwargs
) -> 'BaseFunctionModel':
"""
Sets the values from the response
"""
if name: self.function_name = name
self.function_usage = response.usage
if response.response_ms: self.function_duration = response.response_ms / 1000
self.function_model = response.model

@property
def function_cost(self) -> Optional[float]:
"""
Returns the function consumption
"""
if not self.function_model: return None
if not self.function_usage: return None
from async_openai.types.context import ModelContextHandler
return ModelContextHandler.get_consumption_cost(self.function_model, self.function_usage)

@property
def function_cost_string(self) -> Optional[str]:
"""
Returns the function consumption as a pretty string
"""
return f"${self.function_cost:.2f}" if self.function_cost else None


if PYD_VERSION == 2:
model_config = ConfigDict(json_schema_extra = schema_extra, arbitrary_types_allowed = True)
else:
class Config:
json_schema_extra = schema_extra
arbitrary_types_allowed = True



FunctionSchemaT = TypeVar('FunctionSchemaT', bound = BaseFunctionModel)
FunctionResultT = TypeVar('FunctionResultT', bound = BaseFunctionModel)
Expand Down Expand Up @@ -453,17 +501,13 @@ def parse_response(
schema = schema or self.schema
try:
result = schema.model_validate(response.function_results[0].arguments, from_attributes = True)
if include_name:
result._name = self.name
result.usage = response.usage
result._set_values_from_response(response, name = self.name if include_name else None)
return result
except Exception as e:
self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object: {e}\n{response.text}\n{response.function_results[0].arguments}")
try:
result = schema.model_validate(resolve_json(response.function_results[0].arguments), from_attributes = True)
if include_name:
result._name = self.name
result.usage = response.usage
result._set_values_from_response(response, name = self.name if include_name else None)
return result
except Exception as e:
self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object after fixing")
Expand Down Expand Up @@ -864,7 +908,7 @@ def execute(
result: 'FunctionResultT' = self.cache.fetch(key)
if result:
if isinstance(result, dict): result = function.schema.model_validate(result)
result._name = function.name
result.function_name = function.name
cache_hit = True

if result is None:
Expand Down Expand Up @@ -905,7 +949,7 @@ async def aexecute(
result: 'FunctionResultT' = await self.cache.afetch(key)
if result:
if isinstance(result, dict): result = function.schema.model_validate(result)
result._name = function.name
result.function_name = function.name
cache_hit = True

if result is None:
Expand Down
2 changes: 1 addition & 1 deletion async_openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '0.0.50rc1'
VERSION = '0.0.50rc2'

0 comments on commit dcac7a7

Please sign in to comment.