Skip to content

Commit

Permalink
core: use ParamSpec for better type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
karlicoss committed Jun 7, 2023
1 parent 2661e23 commit 0d5cc15
Showing 1 changed file with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime, date
import stat
from pathlib import Path
import sys
import time
import sqlite3
import typing
Expand Down Expand Up @@ -573,7 +574,25 @@ def __exit__(self, *args) -> None:
self.engine.dispose()


HashFunction = Callable[..., SourceHash]
R = TypeVar('R')
# ugh. python < 3.10 doesn't have ParamSpec and it seems tricky to backport it in compatible manner
if sys.version_info[:2] >= (3, 10) or TYPE_CHECKING:
if sys.version_info[:2] >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
P = ParamSpec('P')
CC = Callable[P, R] # need to give it a name, if inlined into bound=, mypy runs in a bug
PathProvider = Union[PathIsh, Callable[P, PathIsh]]
HashFunction = Callable[P, SourceHash]
else:
# just use some dummy types so runtime is happy
P = TypeVar('P')
CC = Any
PathProvider = Union[P, Any]
HashFunction = Union[P, Any]

F = TypeVar('F', bound=CC)


def default_hash(*args, **kwargs) -> SourceHash:
Expand Down Expand Up @@ -657,7 +676,6 @@ def new_dec(*args, **kwargs):
return new_dec


PathProvider = Union[PathIsh, Callable[..., PathIsh]]


def cachew_error(e: Exception) -> None:
Expand All @@ -677,12 +695,12 @@ def cachew_error(e: Exception) -> None:
@doublewrap
def cachew_impl(
func=None,
cache_path: Optional[PathProvider]=use_default_path,
force_file: bool=False,
cls=None,
depends_on: HashFunction=default_hash,
logger=None,
chunk_by=100,
cache_path: Optional[PathProvider[P]] = use_default_path,
force_file: bool = False,
cls: Optional[Type] = None,
depends_on: HashFunction[P] = default_hash,
logger: Optional[logging.Logger] = None,
chunk_by: int = 100,
# NOTE: allowed values for chunk_by depend on the system.
# some systems (to be more specific, sqlite builds), it might be too large and cause issues
# ideally this would be more defensive/autodetected, maybe with a warning?
Expand Down Expand Up @@ -773,7 +791,7 @@ def cachew_impl(
func =func,
cache_path =cache_path,
force_file =force_file,
cls =cls,
cls_ =cls,
depends_on =depends_on,
logger =logger,
chunk_by =chunk_by,
Expand All @@ -789,25 +807,25 @@ def binder(*args, **kwargs):


if TYPE_CHECKING:
F = TypeVar('F', bound=Callable)

# we need two versions due to @doublewrap
# this is when we just annotate as @cachew without any args
@overload # type: ignore[no-overload-impl]
def cachew(fun: F) -> F:
...

# TODO PathProvider here could benefit from paramspec??
# NOTE: we won't really be able to make sure the args of cache_path are the same as args of the wrapped function
# because when cachew() is called, we don't know anything about the wrapped function yet
# but at least it works for checking that cachew_path and depdns_on have the same args :shrug:
@overload
def cachew(
cache_path: Optional[PathProvider]=...,
cache_path: Optional[PathProvider[P]] = ...,
*,
force_file: bool=...,
cls=...,
depends_on: HashFunction=...,
logger=...,
chunk_by: int=...,
synthetic_key: Optional[str]=...,
force_file: bool = ...,
cls: Optional[Type] = ...,
depends_on: HashFunction[P] = ...,
logger: Optional[logging.Logger] = ...,
chunk_by: int = ...,
synthetic_key: Optional[str] = ...,
) -> Callable[[F], F]:
...
else:
Expand All @@ -826,12 +844,13 @@ def cname(func: Callable) -> str:
_DEPENDENCIES = 'dependencies'


class Context(NamedTuple):
@dataclasses.dataclass
class Context(Generic[P]):
func : Callable
cache_path : PathProvider
cache_path : PathProvider[P]
force_file : bool
cls : Type
depends_on : HashFunction
cls_ : Type
depends_on : HashFunction[P]
logger : logging.Logger
chunk_by : int
synthetic_key: Optional[str]
Expand All @@ -852,7 +871,7 @@ def composite_hash(self, *args, **kwargs) -> Dict[str, Any]:
if k in hsig.parameters or 'kwargs' in hsig.parameters
}
kwargs = {**defaults, **kwargs}
binder = NTBinder.make(tp=self.cls)
binder = NTBinder.make(tp=self.cls_)
schema = str(binder.columns) # todo not super nice, but works fine for now
hash_parts = {
'cachew' : CACHEW_VERSION,
Expand All @@ -871,14 +890,14 @@ def composite_hash(self, *args, **kwargs) -> Dict[str, Any]:

def cachew_wrapper(
*args,
_cachew_context: Context,
_cachew_context: Context[P],
**kwargs,
):
C = _cachew_context
func = C.func
cache_path = C.cache_path
force_file = C.force_file
cls = C.cls
cls = C.cls_
depends_on = C.depends_on
logger = C.logger
chunk_by = C.chunk_by
Expand Down

0 comments on commit 0d5cc15

Please sign in to comment.