Skip to content

Commit

Permalink
Use ParamSpec to capture types of decorated solvers, tools, scorers, …
Browse files Browse the repository at this point in the history
…and metrics (#732)

* solver type-checking via paramspec

* paramspec for tools

* paramspec for metric and scorer

* dont import MetricType

* more type fixup in metric tests

* update changelog

* simplfy usetools signature
  • Loading branch information
jjallaire authored Oct 20, 2024
1 parent c23174a commit 3e913aa
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 87 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Google: compatibility with google-generativeai v0.8.3
- Llama: remove extraneous <|start_header_id|>assistant<|end_header_id|> if it appears in an assistant message.
- Use Dockerhub aisiuk/inspect-web-browser-tool image for web browser tool.
- Use ParamSpec to capture types of decorated solvers, tools, scorers, and metrics.
- Requirements: require semver>=3.0.0
- Added `delimiter` option to `csv_dataset()` (defaults to ",")
- Improve answer detection in multiple choice scorer.
Expand Down
4 changes: 2 additions & 2 deletions src/inspect_ai/_util/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def registry_tag(
type: Callable[..., Any],
o: object,
info: RegistryInfo,
*args: list[Any],
**kwargs: dict[str, Any],
*args: Any,
**kwargs: Any,
) -> None:
r"""Tag an object w/ registry info.
Expand Down
33 changes: 13 additions & 20 deletions src/inspect_ai/scorer/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import (
Any,
Callable,
ParamSpec,
Protocol,
TypeVar,
Union,
cast,
overload,
Expand Down Expand Up @@ -187,15 +187,10 @@ class Metric(Protocol):
def __call__(self, scores: list[Score]) -> Value: ...


MetricType = TypeVar("MetricType", Callable[..., Metric], type[Metric])
r"""Metric type.
Valid metric types include:
- Functions that return a Metric
- Classes derived from Metric
"""
P = ParamSpec("P")


def metric_register(metric: MetricType, name: str = "") -> MetricType:
def metric_register(metric: Callable[P, Metric], name: str = "") -> Callable[P, Metric]:
r"""Register a function or class as a metric.
Args:
Expand Down Expand Up @@ -229,19 +224,17 @@ def metric_create(name: str, **kwargs: Any) -> Metric:


@overload
def metric(name: str) -> Callable[..., MetricType]: ...
def metric(name: str) -> Callable[[Callable[P, Metric]], Callable[P, Metric]]: ...


@overload
# type: ignore
def metric(name: Callable[..., Metric]) -> Callable[..., Metric]: ...


@overload
def metric(name: type[Metric]) -> type[Metric]: ...
def metric(name: Callable[P, Metric]) -> Callable[P, Metric]: ...


def metric(name: str | MetricType) -> Callable[..., MetricType] | MetricType:
def metric(
name: str | Callable[P, Metric],
) -> Callable[[Callable[P, Metric]], Callable[P, Metric]] | Callable[P, Metric]:
r"""Decorator for registering metrics.
Args:
Expand All @@ -257,13 +250,13 @@ def metric(name: str | MetricType) -> Callable[..., MetricType] | MetricType:
# (b) Ensure that instances of Metric created by MetricType also
# carry registry info.
def create_metric_wrapper(
metric_type: MetricType, name: str | None = None
) -> MetricType:
metric_type: Callable[P, Metric], name: str | None = None
) -> Callable[P, Metric]:
metric_name = registry_name(
metric_type, name if name else getattr(metric_type, "__name__")
)

def metric_wrapper(*args: Any, **kwargs: Any) -> Metric:
def metric_wrapper(*args: P.args, **kwargs: P.kwargs) -> Metric:
metric = metric_type(*args, **kwargs)
registry_tag(
metric_type,
Expand All @@ -274,12 +267,12 @@ def metric_wrapper(*args: Any, **kwargs: Any) -> Metric:
)
return metric

return metric_register(cast(MetricType, metric_wrapper), metric_name)
return metric_register(cast(Callable[P, Metric], metric_wrapper), metric_name)

# for decorators with an explicit name, one more wrapper for the name
if isinstance(name, str):

def wrapper(metric_type: MetricType) -> MetricType:
def wrapper(metric_type: Callable[P, Metric]) -> Callable[P, Metric]:
return create_metric_wrapper(metric_type, name)

return wrapper
Expand Down
22 changes: 8 additions & 14 deletions src/inspect_ai/scorer/_scorer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import (
Any,
Callable,
ParamSpec,
Protocol,
TypeVar,
cast,
runtime_checkable,
)
Expand Down Expand Up @@ -42,18 +42,12 @@ async def __call__(
) -> Score: ...


ScorerType = TypeVar("ScorerType", Callable[..., Scorer], type[Scorer])
r"""Scorer type.
Valid scorer types include:
- Functions that return a Scorer
- Classes derived from Scorer
"""
P = ParamSpec("P")


def scorer_register(
scorer: ScorerType, name: str = "", metadata: dict[str, Any] = {}
) -> ScorerType:
scorer: Callable[P, Scorer], name: str = "", metadata: dict[str, Any] = {}
) -> Callable[P, Scorer]:
r"""Register a function or class as a scorer.
Args:
Expand Down Expand Up @@ -91,7 +85,7 @@ def scorer(
metrics: list[Metric | dict[str, list[Metric]]] | dict[str, list[Metric]],
name: str | None = None,
**metadata: Any,
) -> Callable[[Callable[..., Scorer]], Callable[..., Scorer]]:
) -> Callable[[Callable[P, Scorer]], Callable[P, Scorer]]:
r"""Decorator for registering scorers.
Args:
Expand All @@ -109,14 +103,14 @@ def scorer(
"""

def wrapper(scorer_type: ScorerType) -> ScorerType:
def wrapper(scorer_type: Callable[P, Scorer]) -> Callable[P, Scorer]:
# determine the name (explicit or implicit from object)
scorer_name = registry_name(
scorer_type, name if name else getattr(scorer_type, "__name__")
)

# wrap instantiations of scorer so they carry registry info and metrics
def scorer_wrapper(*args: Any, **kwargs: Any) -> Scorer:
def scorer_wrapper(*args: P.args, **kwargs: P.kwargs) -> Scorer:
scorer = scorer_type(*args, **kwargs)

if not is_callable_coroutine(scorer):
Expand All @@ -139,7 +133,7 @@ def scorer_wrapper(*args: Any, **kwargs: Any) -> Scorer:

# register the scorer
return scorer_register(
scorer=cast(ScorerType, scorer_wrapper),
scorer=cast(Callable[P, Scorer], scorer_wrapper),
name=scorer_name,
metadata={SCORER_METRICS: metrics} | metadata,
)
Expand Down
3 changes: 1 addition & 2 deletions src/inspect_ai/solver/_basic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from inspect_ai.scorer._score import score
from inspect_ai.solver._chain import chain
from inspect_ai.tool._tool import Tool, ToolResult, tool
from inspect_ai.tool._tool_def import ToolDef
from inspect_ai.tool._tool_with import tool_with

from ._prompt import system_message
Expand Down Expand Up @@ -50,7 +49,7 @@ class BasicAgentDeprecatedArgs(TypedDict, total=False):
def basic_agent(
*,
init: Solver | list[Solver] | None = None,
tools: list[Tool | ToolDef] | Solver | None = None,
tools: list[Tool] | Solver | None = None,
cache: bool | CachePolicy = False,
max_attempts: int = 1,
message_limit: int | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/inspect_ai/solver/_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@solver
def prompt_template(template: str, **params: dict[str, Any]) -> Solver:
def prompt_template(template: str, **params: Any) -> Solver:
"""Parameterized prompt template.
Prompt template containing a `{prompt}` placeholder and any
Expand All @@ -37,7 +37,7 @@ async def solve(state: TaskState, generate: Generate) -> TaskState:


@solver
def system_message(template: str, **params: dict[str, Any]) -> Solver:
def system_message(template: str, **params: Any) -> Solver:
"""Solver which inserts a system message into the conversation.
System message template containing any number of optional `params`.
Expand Down
41 changes: 16 additions & 25 deletions src/inspect_ai/solver/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
Any,
Callable,
Literal,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
runtime_checkable,
Expand Down Expand Up @@ -98,20 +98,14 @@ async def __call__(
) -> TaskState: ...


SolverType = TypeVar("SolverType", Callable[..., Solver], type[Solver])
r"""Solver type.
P = ParamSpec("P")

Valid solver types include:
- Functions that return a Solver
- Classes derived from Solver
"""


def solver_register(solver: SolverType, name: str = "") -> SolverType:
def solver_register(solver: Callable[P, Solver], name: str = "") -> Callable[P, Solver]:
r"""Register a function or class as a solver.
Args:
solver (SolverType):
solver (Callable[P, Solver]):
Function that returns a Solver or class derived Solver.
name (str): Name of solver (Optional, defaults to object name)
Expand All @@ -137,25 +131,22 @@ def solver_create(name: str, **kwargs: Any) -> Solver:


@overload
def solver(name: str) -> Callable[..., SolverType]: ...


@overload
# type: ignore
def solver(name: Callable[..., Solver]) -> Callable[..., Solver]: ...
def solver(name: str) -> Callable[[Callable[P, Solver]], Callable[P, Solver]]: ...


@overload
def solver(name: type[Solver]) -> type[Solver]: ...
def solver(name: Callable[P, Solver]) -> Callable[P, Solver]: ...


def solver(name: str | SolverType) -> Callable[..., SolverType] | SolverType:
def solver(
name: str | Callable[P, Solver],
) -> Callable[[Callable[P, Solver]], Callable[P, Solver]] | Callable[P, Solver]:
r"""Decorator for registering solvers.
Args:
name: (str | SolverType):
name: (str | Callable[P, Solver]):
Optional name for solver. If the decorator has no name
argument then the name of the underlying SolverType
argument then the name of the underlying Callable[P, Solver]
object will be used to automatically assign a name.
Returns:
Expand Down Expand Up @@ -183,13 +174,13 @@ def solve(state: TaskState, generate: Generate) -> None:
# (b) Ensure that instances of Solver created by SolverType also
# carry registry info.
def create_solver_wrapper(
solver_type: SolverType, name: str | None = None
) -> SolverType:
solver_type: Callable[P, Solver], name: str | None = None
) -> Callable[P, Solver]:
solver_name = registry_name(
solver_type, name if name else getattr(solver_type, "__name__")
)

def solver_wrapper(*args: Any, **kwargs: dict[str, Any]) -> Solver:
def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
solver = solver_type(*args, **kwargs)

if not is_callable_coroutine(solver):
Expand Down Expand Up @@ -234,12 +225,12 @@ async def registered_solver(

return registered_solver

return solver_register(cast(SolverType, solver_wrapper), solver_name)
return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)

# for decorators with an explicit name, one more wrapper for the name
if isinstance(name, str):

def wrapper(solver_type: SolverType) -> SolverType:
def wrapper(solver_type: Callable[..., Solver]) -> Callable[..., Solver]:
return create_solver_wrapper(solver_type, name)

return wrapper
Expand Down
2 changes: 1 addition & 1 deletion src/inspect_ai/solver/_use_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@solver
def use_tools(
*tools: Tool | ToolDef | list[Tool | ToolDef],
*tools: Tool | list[Tool],
tool_choice: ToolChoice | None = "auto",
) -> Solver:
"""
Expand Down
Loading

0 comments on commit 3e913aa

Please sign in to comment.