Skip to content

Commit

Permalink
update iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Feb 6, 2024
1 parent dcac7a7 commit baf861b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 14 deletions.
3 changes: 3 additions & 0 deletions async_openai/types/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def get_consumption_cost(cls, model_name: str, usage: 'Usage', **kwargs) -> floa
# Switch the 35 -> 3.5
if '35' in model_name: model_name = model_name.replace('35', '3.5')
model = cls[model_name]
if isinstance(usage, dict):
from .resources import Usage
usage = Usage(**usage)
return model.get_costs(usage = usage, **kwargs)

def resolve_model_name(cls, model_name: str) -> str:
Expand Down
109 changes: 96 additions & 13 deletions async_openai/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jinja2
import functools
import inspect
from abc import ABC
from pydantic import Field, BaseModel
# from lazyops.types import BaseModel
Expand All @@ -15,7 +16,7 @@
from async_openai.utils.fixjson import resolve_json
from . import errors

from typing import Optional, Any, Dict, List, Union, Type, Tuple, Awaitable, Generator, AsyncGenerator, TypeVar, TYPE_CHECKING
from typing import Optional, Any, Set, Dict, List, Union, Type, Tuple, Awaitable, Generator, AsyncGenerator, TypeVar, TYPE_CHECKING

if PYD_VERSION == 2:
from pydantic import ConfigDict
Expand Down Expand Up @@ -644,6 +645,13 @@ async def arun_function(
"""
messages, model = await self.aprepare_function_inputs(model = model, **kwargs)
return await self.arun_function_loop(messages = messages, model = model, **kwargs)

def get_function_kwargs(self) -> Dict[str, Any]:
"""
Returns the function kwargs
"""
sig = inspect.signature(self.arun_function)
return [p.name for p in sig.parameters.values() if p.kind in {p.KEYWORD_ONLY, p.VAR_KEYWORD, p.POSITIONAL_OR_KEYWORD} and p.name not in {'kwargs', 'args', 'model'}]

"""
Handle a Loop
Expand Down Expand Up @@ -849,17 +857,23 @@ def register_function(
self.functions[name] = func
self.autologger.info(f"Registered Function: |g|{name}|e|", colored=True)

def create_hash(self, **kwargs) -> str:
def create_hash(self, *args, **kwargs) -> str:
"""
Creates a hash
"""
return self._hash_func(self._pickle.dumps(kwargs)).hexdigest()
key = args or ()
kwargs = {k: v for k, v in kwargs.items() if v is not None}
sorted_items = sorted(kwargs.items())
for item in sorted_items:
key += item
key = ':'.join(str(k) for k in key)
return self._hash_func(key).hexdigest()

async def acreate_hash(self, **kwargs) -> str:
async def acreate_hash(self, *args, **kwargs) -> str:
"""
Creates a hash
"""
return await self.api.pooler.asyncish(self.create_hash, **kwargs)
return await self.api.pooler.asyncish(self.create_hash, *args, **kwargs)

def _get_function(self, name: str) -> Optional[BaseFunction]:
"""
Expand All @@ -879,15 +893,48 @@ def get(self, name: Union[str, 'FunctionT']) -> Optional['FunctionT']:
return name if isinstance(name, BaseFunction) else self._get_function(name)


def parse_iterator_func(
self,
function: 'FunctionT',
*args,
with_index: Optional[bool] = False,
**function_kwargs,
) -> Tuple[int, Set, Dict[str, Any]]:
"""
Parses the iterator function kwargs
"""
func_iter_arg = args[0]
args = args[1:]
idx = None
if with_index: idx, item = func_iter_arg
else: item = func_iter_arg
_func_kwargs = function.get_function_kwargs()
if isinstance(item, dict) and any(k in _func_kwargs for k in item):
function_kwargs.update(item)
else:
# Get the missing function kwargs
_added = False
for k in _func_kwargs:
if k not in function_kwargs:
function_kwargs[k] = item
# self.autologger.info(f"Added missing function kwarg: {k} = {item}", prefix = function.name, colored = True)
_added = True
break
if not _added:
# If not, then add the item as the first argument
args = (item,) + args
return idx, args, function_kwargs

def execute(
self,
function: Union['FunctionT', str],
*args,
item_hashkey: Optional[str] = None,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
with_index: Optional[bool] = False,
**function_kwargs
) -> Optional['FunctionSchemaT']:
) -> Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]]:
"""
Runs the function
"""
Expand All @@ -896,6 +943,11 @@ def execute(
if overwrite and self.check_value_present(overrides, f'{function.name}.cachable'):
cachable = False

# Iterators
is_iterator = function_kwargs.pop('_is_iterator', False)
if is_iterator:
idx, args, function_kwargs = self.parse_iterator_func(function, *args, with_index = with_index, **function_kwargs)

if item_hashkey is None: item_hashkey = self.create_hash(**function_kwargs)
key = f'{item_hashkey}.{function.name}'
if function.has_diff_model_than_default:
Expand All @@ -917,6 +969,8 @@ def execute(
self.cache.set(key, result)

self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True)
if is_iterator and with_index:
return idx, result if function.is_valid_response(result) else (idx, None)
return result if function.is_valid_response(result) else None


Expand All @@ -927,17 +981,24 @@ async def aexecute(
item_hashkey: Optional[str] = None,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
with_index: Optional[bool] = False,
**function_kwargs
) -> Optional['FunctionSchemaT']:
) -> Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]]:
# sourcery skip: low-code-quality
"""
Runs the function
"""
overwrite = overrides and 'functions' in overrides
function = self.get(function)
if overwrite and self.check_value_present(overrides, f'{function.name}.cachable'):
cachable = False

# Iterators
is_iterator = function_kwargs.pop('_is_iterator', False)
if is_iterator:
idx, args, function_kwargs = self.parse_iterator_func(function, *args, with_index = with_index, **function_kwargs)

if item_hashkey is None: item_hashkey = await self.acreate_hash(**function_kwargs)
if item_hashkey is None: item_hashkey = await self.acreate_hash(*args, **function_kwargs)
key = f'{item_hashkey}.{function.name}'
if function.has_diff_model_than_default:
key += f'.{function.default_model_func}'
Expand All @@ -958,6 +1019,8 @@ async def aexecute(
await self.cache.aset(key, result)

self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True)
if is_iterator and with_index:
return idx, result if function.is_valid_response(result) else (idx, None)
return result if function.is_valid_response(result) else None


Expand Down Expand Up @@ -1005,23 +1068,28 @@ def check_value_present(
def map(
self,
function: Union['FunctionT', str],
iterable_kwargs: List[Dict[str, Any]],
iterable_kwargs: List[Union[Dict[str, Any], Any]],
*args,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = True,
with_index: Optional[bool] = False,
**function_kwargs
) -> List[Optional['FunctionSchemaT']]:
) -> List[Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]]]:
"""
Maps the function to the iterable in parallel
"""
partial = functools.partial(
self.execute,
function,
# *args,
cachable = cachable,
overrides = overrides,
_is_iterator = True,
with_index = with_index,
**function_kwargs
)
if with_index: iterable_kwargs = list(enumerate(iterable_kwargs))
return self.api.pooler.map(partial, iterable_kwargs, *args, return_ordered = return_ordered)

async def amap(
Expand All @@ -1033,18 +1101,23 @@ async def amap(
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = True,
concurrency_limit: Optional[int] = None,
with_index: Optional[bool] = False,
**function_kwargs
) -> List[Optional['FunctionSchemaT']]:
) -> List[Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]]]:
"""
Maps the function to the iterable in parallel
"""
partial = functools.partial(
self.aexecute,
function,
# *args,
cachable = cachable,
overrides = overrides,
_is_iterator = True,
with_index = with_index,
**function_kwargs
)
if with_index: iterable_kwargs = list(enumerate(iterable_kwargs))
return await self.api.pooler.amap(partial, iterable_kwargs, *args, return_ordered = return_ordered, concurrency_limit = concurrency_limit)

def iterate(
Expand All @@ -1055,18 +1128,23 @@ def iterate(
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = False,
with_index: Optional[bool] = False,
**function_kwargs
) -> Generator[Optional['FunctionSchemaT'], None, None]:
) -> Generator[Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]], None, None]:
"""
Maps the function to the iterable in parallel
"""
partial = functools.partial(
self.execute,
function,
# *args,
cachable = cachable,
overrides = overrides,
_is_iterator = True,
with_index = with_index,
**function_kwargs
)
if with_index: iterable_kwargs = list(enumerate(iterable_kwargs))
return self.api.pooler.iterate(partial, iterable_kwargs, *args, return_ordered = return_ordered)

def aiterate(
Expand All @@ -1078,18 +1156,23 @@ def aiterate(
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = False,
concurrency_limit: Optional[int] = None,
with_index: Optional[bool] = False,
**function_kwargs
) -> AsyncGenerator[Optional['FunctionSchemaT'], None]:
) -> AsyncGenerator[Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]], None]:
"""
Maps the function to the iterable in parallel
"""
partial = functools.partial(
self.aexecute,
function,
# *args,
cachable = cachable,
overrides = overrides,
_is_iterator = True,
with_index = with_index,
**function_kwargs
)
if with_index: iterable_kwargs = list(enumerate(iterable_kwargs))
return self.api.pooler.aiterate(partial, iterable_kwargs, *args, return_ordered = return_ordered, concurrency_limit = concurrency_limit)

def __call__(
Expand Down
2 changes: 1 addition & 1 deletion async_openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '0.0.50rc2'
VERSION = '0.0.50rc3'

0 comments on commit baf861b

Please sign in to comment.