Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to signature.py #964

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 117 additions & 35 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,41 @@
import types
import typing
from copy import deepcopy
from typing import Any, Dict, Tuple, Type, Union # noqa: UP035
from typing import Dict, List, Tuple, Type, Union, Optional, Mapping, cast, Callable # noqa: UP035

import pydantic.fields
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

import dsp
from dspy.signatures.field import InputField, OutputField, new_to_old_field


def signature_to_template(signature) -> dsp.Template:
def signature_to_template(signature: "SignatureMeta") -> dsp.Template:
"""Convert from new to legacy format."""
return dsp.Template(
signature.instructions,
**{name: new_to_old_field(field) for name, field in signature.fields.items()},
)


def _default_instructions(cls) -> str:
def _default_instructions(cls: "SignatureMeta") -> str:
inputs_ = ", ".join([f"`{field}`" for field in cls.input_fields])
outputs_ = ", ".join([f"`{field}`" for field in cls.output_fields])
return f"Given the fields {inputs_}, produce the fields {outputs_}."


class SignatureMeta(type(BaseModel)):
# I don't love this, but mypy cannot deal with the dynamic lookup of this metaclass - rpg
PydanticMetaClass: typing.TypeAlias = pydantic._internal._model_construction.ModelMetaclass

assert PydanticMetaClass == type(pydantic.BaseModel), \
"Implementation of Signature MetaClass depends on the internals of Pydantic, which have changed."


class SignatureMeta(PydanticMetaClass):

model_fields: Dict[str, FieldInfo]

def __call__(cls, *args, **kwargs): # noqa: ANN002
if cls is Signature:
return make_signature(*args, **kwargs)
Expand All @@ -43,7 +54,10 @@ def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
namespace["__annotations__"] = raw_annotations

# Let Pydantic do its thing
cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs)
cls = typing.cast(
SignatureMeta,
super().__new__(mcs, signature_name, bases, namespace, **kwargs),
)

# If we don't have instructions, it might be because we are a derived generic type.
# In that case, we should inherit the instructions from the base class.
Expand Down Expand Up @@ -91,7 +105,9 @@ def instructions(cls) -> str:
return getattr(cls, "__doc__", "")

def with_instructions(cls, instructions: str) -> Type["Signature"]:
return Signature(cls.fields, instructions)
# FIXME: mypy believes that this is an illegitimate invocation of Signature,
# which it believes can only take 1 argument - rpg
return Signature(cls.fields, instructions) # type: ignore

@property
def fields(cls) -> dict[str, FieldInfo]:
Expand All @@ -104,14 +120,16 @@ def with_updated_fields(cls, name, type_=None, **kwargs) -> Type["Signature"]:
Returns a new Signature type with the field, name, updated
with fields[name].json_schema_extra[key] = value.
"""
fields_copy = deepcopy(cls.fields)
fields_copy: Dict[str, FieldInfo] = deepcopy(cls.fields)
fields_copy[name].json_schema_extra = {
**fields_copy[name].json_schema_extra,
# Unable to make mypy realize that the following is OK
**fields_copy[name].json_schema_extra, # type: ignore
**kwargs,
}
if type_ is not None:
fields_copy[name].annotation = type_
return Signature(fields_copy, cls.instructions)
# FIXME: another place where mypy doesn't like the signature of Signature
return cast(Type["Signature"], Signature(fields_copy, cls.instructions)) # type: ignore

@property
def input_fields(cls) -> dict[str, FieldInfo]:
Expand All @@ -122,16 +140,25 @@ def output_fields(cls) -> dict[str, FieldInfo]:
return cls._get_fields_with_type("output")

def _get_fields_with_type(cls, field_type) -> dict[str, FieldInfo]:
return {k: v for k, v in cls.model_fields.items() if v.json_schema_extra["__dspy_field_type"] == field_type}
# FIXME: the following assumes that json_schema_extra will always be
# a JsonDict, but according to mypy it could be None, or a Callable.
# I'm trusting the code, but uncomfortable.
return {
k: v
for k, v in cls.model_fields.items()
if v.json_schema_extra["__dspy_field_type"] == field_type # type: ignore
}

def prepend(cls, name, field, type_=None) -> Type["Signature"]:
def prepend(cls, name, field: FieldInfo, type_=None) -> Type["Signature"]:
return cls.insert(0, name, field, type_)

def append(cls, name, field, type_=None) -> Type["Signature"]:
def append(cls, name, field: FieldInfo, type_=None) -> Type["Signature"]:
return cls.insert(-1, name, field, type_)

def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]:
# It's posisble to set the type as annotation=type in pydantic.Field(...)
def insert(
cls, index: int, name: str, field: FieldInfo, type_: Optional[Type] = None
) -> Type["Signature"]:
# It's possible to set the type as annotation=type in pydantic.Field(...)
# But this may be annoying for users, so we allow them to pass the type
if type_ is None:
type_ = field.annotation
Expand All @@ -142,7 +169,11 @@ def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signa
output_fields = list(cls.output_fields.items())

# Choose the list to insert into based on the field type
lst = input_fields if field.json_schema_extra["__dspy_field_type"] == "input" else output_fields
lst = (
input_fields
if field.json_schema_extra["__dspy_field_type"] == "input"
else output_fields
)
# We support negative insert indices
if index < 0:
index += len(lst) + 1
Expand All @@ -151,12 +182,15 @@ def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signa
lst.insert(index, (name, (type_, field)))

new_fields = dict(input_fields + output_fields)
return Signature(new_fields, cls.instructions)
# FIXME: another Signature call that mypy does not understand
return Signature(new_fields, cls.instructions) # type: ignore

def equals(cls, other) -> bool:
"""Compare the JSON schema of two Pydantic models."""
if not isinstance(other, type) or not issubclass(other, BaseModel):
return False
# this cast tells mypy that `other` will have instructions and fields
other = cast(Union[Type["Signature"]], other)
if cls.instructions != other.instructions:
return False
for name in cls.fields.keys() | other.fields.keys():
Expand Down Expand Up @@ -198,28 +232,44 @@ def __repr__(cls):
#
# For compatibility with the legacy dsp format, you can use the signature_to_template function.
#
class Signature(BaseModel, metaclass=SignatureMeta):
"" # noqa: D419

class Signature(BaseModel, metaclass=SignatureMeta):
"""""" # noqa: D419
# FIXME: This attempt to tell mypy about the parameters to `Signature()`
# falls afoul of the definition in BaseModel, which mypy believes to be:
# `Callable[[KwArg(Any)], None]`
# Here's the header of that function: def __init__(self, /, **data: Any) -> None:
# I thought that the `/` meant that the callable did *not* only take kwargs.
__init__: Callable[[Type["Signature"], str, Optional[str]], None]
# Note: Don't put a docstring here, as it will become the default instructions
# for any signature that doesn't define it's own instructions.
# for any signature that doesn't define its own instructions.
pass


def ensure_signature(signature: Union[str, Type[Signature]], instructions=None) -> Signature:


def ensure_signature(
signature: Union[str, Type[Signature]], instructions: Optional[str]=None
) -> Type[Signature]:
# FIXME: the following would force me to type ensure_signature() as -> Optional[Type[Signature]]
# instead of simply -> Type[Signature]. Should this be an error?
if signature is None:
return None
if isinstance(signature, str):
return Signature(signature, instructions)
# FIXME: According to mypy, Signature cannot be called with 2 arguments.
# return Signature(signature, instructions)
return make_signature(signature, instructions=instructions)
if instructions is not None:
raise ValueError("Don't specify instructions when initializing with a Signature")
raise ValueError(
"Don't specify instructions when initializing with a Signature"
)
return signature


def make_signature(
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
instructions: str = None,
signature_name: str = "StringSignature",
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
instructions: Optional[str] = None,
signature_name: str = "StringSignature",
) -> Type[Signature]:
"""Create a new Signature type with the given fields and instructions.

Expand Down Expand Up @@ -264,7 +314,7 @@ def make_signature(

# Default prompt when no instructions are provided
if instructions is None:
sig = Signature(signature, "") # Simple way to parse input/output fields
sig: Type["Signature"] = Signature(signature, "") # Simple way to parse input/output fields
instructions = _default_instructions(sig)

return create_model(
Expand All @@ -275,9 +325,11 @@ def make_signature(
)


def _parse_signature(signature: str) -> Tuple[Type, Field]:
def _parse_signature(signature: str) -> Dict[str, Tuple[Type, pydantic.fields.FieldInfo]]:
if signature.count("->") != 1:
raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.")
raise ValueError(
f"Invalid signature format: '{signature}', must contain exactly one '->'."
)

inputs_str, outputs_str = signature.split("->")

Expand All @@ -290,17 +342,37 @@ def _parse_signature(signature: str) -> Tuple[Type, Field]:
return fields


def _parse_arg_string(string: str, names=None) -> Dict[str, str]:
def _parse_arg_string(string: str, names=None) -> typing.Iterator[Tuple[str, Type]]:
args = ast.parse("def f(" + string + "): pass").body[0].args.args
names = [arg.arg for arg in args]
types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args]
# noinspection PyShadowingNames
types = [
str if arg.annotation is None else _parse_type_node(arg.annotation)
for arg in args
]
return zip(names, types)


def _parse_type_node(node, names=None) -> Any:
"""Recursively parse an AST node representing a type annotation.
def _parse_type_node(
node: Union[ast.Module, ast.Expr, ast.Name, ast.Subscript, ast.Tuple, ast.Call, ast.stmt, ast.expr],
names: Optional[Mapping[str, Type]] = None,
) -> Union[Type, Tuple[Type, ...]]:
"""
Recursively parse an AST node representing a type annotation.


without using structural pattern matching introduced in Python 3.10.

Parameters
----------
node : ast node
names : mapping from type names to types, optional
By default this mapping is populated by the names defined by the `typing`
module.
Returns
-------
type annotated on node

"""
if names is None:
names = typing.__dict__
Expand All @@ -316,6 +388,7 @@ def _parse_type_node(node, names=None) -> Any:
return _parse_type_node(value, names)

if isinstance(node, ast.Name):
# Look up the node's identifier in the set of available types.
id_ = node.id
if id_ in names:
return names[id_]
Expand All @@ -331,11 +404,20 @@ def _parse_type_node(node, names=None) -> Any:

if isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)
# FIXME: the following assumes that none of the recursive calls can return a Tuple
# I don't understand the AST well enough to know if this assumption is justified.
return tuple(_parse_type_node(elt, names) for elt in elts) # type: ignore

if isinstance(node, ast.Call):
if node.func.id == "Field":
keys = [kw.arg for kw in node.keywords]
# FIXME: According to mypy, node.func has type `expr` and `expr` does not
# have an `id` field. I *suspect* this means that the `func` property of the
# Call could be filled with an `expr` that *evaluates to* a function, instead
# of a function? Here's what I see in the documentation of the ast module:
# Call(expr func, expr* args, keyword* keywords)
# Do we need to check that this is the right sort of AST node before
# checking its id? And if so, do we check `isinstance(node.fun, ast.Name)`?
if isinstance(node.func, ast.Name) and node.func.id == "Field":
keys: List[str] = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
return Field(**dict(zip(keys, values)))

Expand Down