Skip to content

Commit

Permalink
Add new strawberry.Parent annotation to support static resolvers with…
Browse files Browse the repository at this point in the history
… non-self types. (#3017)
  • Loading branch information
mattalbr authored Sep 14, 2023
1 parent dd5e388 commit d1ff467
Show file tree
Hide file tree
Showing 14 changed files with 387 additions and 28 deletions.
26 changes: 26 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
Release type: minor

Adds new strawberry.Parent type annotation to support resolvers without use of self.

E.g.

```python
@dataclass
class UserRow:
id_: str


@strawberry.type
class User:
@strawberry.field
@staticmethod
async def name(parent: strawberry.Parent[UserRow]) -> str:
return f"User Number {parent.id}"


@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return UserRow(id_="1234")
```
26 changes: 26 additions & 0 deletions docs/errors/conflicting-arguments.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
title: Conflicting Arguments Error
---

# Conflicting Arguments Error

## Description

This error is thrown when you define a resolver with multiple arguments that
conflict with each other, like "self", "root", or any arguments annotated with
strawberry.Parent.

For example the following code will throw this error:

```python
import strawberry


@strawberry.type
class Query:
@strawberry.field
def hello(
self, root, parent: strawberry.Parent[str]
) -> str: # <-- self, root, and parent all identify the same input
return f"hello world"
```
44 changes: 44 additions & 0 deletions docs/types/resolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,50 @@ class User:
full_name: str = strawberry.field(resolver=full_name)
```

For either decorated resolvers or Python functions, there is a third option
which is to annotate one of the parameters with the type `Parent`. This is
particularly useful if the parent resolver returns a type other than the
strawberry type that the resolver is defined within, as it allows you to
specify the type of the parent object. This comes up particularly often
for resolvers that return ORMs:

```python
import dataclass

import strawberry


@dataclass
class UserRow:
id_: str


@strawberry.type
class User:
@strawberry.field
@staticmethod
async def name(parent: strawberry.Parent[UserRow]) -> str:
return f"User Number {parent.id}"


@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
# Even though this method is annotated as returning type `User`,
# which strawberry uses to define the GraphQL schema, we're
# not actually required to return an object of that type. Whatever
# object we do return will be passed on to child resolvers that
# request it via the `self`, `root`, or `strawberry.Parent` parameter.
# In this case, we return our ORM directly.
#
# Put differently, the GraphQL schema and associated resolvers come
# from the type annotations, but the actual object passed to the
# resolvers via `self`, `root`, or `strawberry.Parent` is whatever
# the parent resolvers return, regardless of type.
return UserRow(id_="1234")
```

## Accessing execution information

Sometimes it is useful to access the information for the current execution
Expand Down
2 changes: 2 additions & 0 deletions strawberry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .lazy_type import LazyType, lazy
from .mutation import mutation, subscription
from .object_type import asdict, input, interface, type
from .parent import Parent
from .permission import BasePermission
from .private import Private
from .scalars import ID
Expand All @@ -23,6 +24,7 @@
"UNSET",
"lazy",
"LazyType",
"Parent",
"Private",
"Schema",
"argument",
Expand Down
2 changes: 2 additions & 0 deletions strawberry/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from graphql import GraphQLError

from .conflicting_arguments import ConflictingArgumentsError
from .duplicated_type_name import DuplicatedTypeName
from .exception import StrawberryException, UnableToFindExceptionSource
from .handler import setup_exception_handler
Expand Down Expand Up @@ -178,6 +179,7 @@ class StrawberryGraphQLError(GraphQLError):
"WrongNumberOfResultsReturned",
"FieldWithResolverAndDefaultValueError",
"FieldWithResolverAndDefaultFactoryError",
"ConflictingArgumentsError",
"MissingQueryError",
"InvalidArgumentTypeError",
"InvalidDefaultFactoryError",
Expand Down
54 changes: 54 additions & 0 deletions strawberry/exceptions/conflicting_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, List, Optional

from .exception import StrawberryException
from .utils.source_finder import SourceFinder

if TYPE_CHECKING:
from strawberry.types.fields.resolver import StrawberryResolver

from .exception_source import ExceptionSource


class ConflictingArgumentsError(StrawberryException):
def __init__(
self,
resolver: StrawberryResolver,
arguments: List[str],
):
self.function = resolver.wrapped_func
self.argument_names = arguments

self.message = (
f"Arguments {self.argument_names_str} define conflicting resources. "
"Only one of these arguments may be defined per resolver."
)

self.rich_message = self.message

self.suggestion = (
f"Only one of {self.argument_names_str} may be defined per resolver."
)

self.annotation_message = self.suggestion

@cached_property
def argument_names_str(self) -> str:
return (
", ".join(f'"{name}"' for name in self.argument_names[:-1])
+ " and "
+ f'"{self.argument_names[-1]}"'
)

@cached_property
def exception_source(self) -> Optional[ExceptionSource]:
if self.function is None:
return None # pragma: no cover

source_finder = SourceFinder()

return source_finder.find_argument_from_object(
self.function, self.argument_names[1] # type: ignore
)
2 changes: 1 addition & 1 deletion strawberry/exceptions/utils/source_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def find_source(self, module: str) -> Optional[SourcePath]:
if not path.exists() or path.suffix != ".py":
return None # pragma: no cover

source = path.read_text()
source = path.read_text(encoding="utf-8")

return SourcePath(path=path, code=source)

Expand Down
39 changes: 39 additions & 0 deletions strawberry/parent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import TypeVar
from typing_extensions import Annotated


class StrawberryParent:
...


T = TypeVar("T")

Parent = Annotated[T, StrawberryParent()]
Parent.__doc__ = """Represents a parameter holding the parent resolver's value.
This can be used when defining a resolver on a type when the parent isn't expected
to return the type itself.
Example:
>>> import strawberry
>>> from dataclasses import dataclass
>>>
>>> @dataclass
>>> class UserRow:
... id_: str
...
>>> @strawberry.type
... class User:
... @strawberry.field
... @staticmethod
... async def name(parent: strawberry.Parent[UserRow]) -> str:
... return f"User Number {parent.id}"
...
>>> @strawberry.type
>>> class Query:
... @strawberry.field
... def user(self) -> User:
... return UserRow(id_="1234")
...
"""
11 changes: 4 additions & 7 deletions strawberry/private.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TypeVar
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated

from strawberry.utils.typing import type_has_annotation


class StrawberryPrivate:
Expand All @@ -22,9 +24,4 @@ class StrawberryPrivate:


def is_private(type_: object) -> bool:
if get_origin(type_) is Annotated:
return any(
isinstance(argument, StrawberryPrivate) for argument in get_args(type_)
)

return False
return type_has_annotation(type_, StrawberryPrivate)
11 changes: 6 additions & 5 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def _get_arguments(
# the following code allows to omit info and root arguments
# by inspecting the original resolver arguments,
# if it asks for self, the source will be passed as first argument
# if it asks for root, the source it will be passed as kwarg
# if it asks for root or parent, the source will be passed as kwarg
# if it asks for info, the info will be passed as kwarg

args = []
Expand All @@ -583,12 +583,13 @@ def _get_arguments(
if field.base_resolver.self_parameter:
args.append(source)

root_parameter = field.base_resolver.root_parameter
if root_parameter:
if parent_parameter := field.base_resolver.parent_parameter:
kwargs[parent_parameter.name] = source

if root_parameter := field.base_resolver.root_parameter:
kwargs[root_parameter.name] = source

info_parameter = field.base_resolver.info_parameter
if info_parameter:
if info_parameter := field.base_resolver.info_parameter:
kwargs[info_parameter.name] = info

return args, kwargs
Expand Down
Loading

0 comments on commit d1ff467

Please sign in to comment.