Skip to content

Commit

Permalink
Improve interface performance using resolve_type (#1949)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com>
  • Loading branch information
skilkis and patrick91 authored Jul 14, 2023
1 parent 1360e30 commit 92e1151
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 19 deletions.
32 changes: 32 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Release type: minor

Improve the time complexity of `strawberry.interface` using `resolve_type`.
Achieved time complexity is now O(1) with respect to the number of
implementations of an interface. Previously, the use of `is_type_of` resulted
in a worst-case performance of O(n).

**Before**:

```shell
---------------------------------------------------------------------------
Name (time in ms) Min Max
---------------------------------------------------------------------------
test_interface_performance[1] 18.0224 (1.0) 50.3003 (1.77)
test_interface_performance[16] 22.0060 (1.22) 28.4240 (1.0)
test_interface_performance[256] 69.1364 (3.84) 76.1349 (2.68)
test_interface_performance[4096] 219.6461 (12.19) 231.3732 (8.14)
---------------------------------------------------------------------------
```

**After**:

```shell
---------------------------------------------------------------------------
Name (time in ms) Min Max
---------------------------------------------------------------------------
test_interface_performance[1] 14.3921 (1.0) 46.2064 (2.79)
test_interface_performance[16] 14.8669 (1.03) 16.5732 (1.0)
test_interface_performance[256] 15.8977 (1.10) 24.4618 (1.48)
test_interface_performance[4096] 18.7340 (1.30) 21.2899 (1.28)
---------------------------------------------------------------------------
```
2 changes: 2 additions & 0 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _process_type(
interfaces = _get_interfaces(cls)
fields = _get_fields(cls)
is_type_of = getattr(cls, "is_type_of", None)
resolve_type = getattr(cls, "resolve_type", None)

cls.__strawberry_definition__ = StrawberryObjectDefinition(
name=name,
Expand All @@ -151,6 +152,7 @@ def _process_type(
extend=extend,
_fields=fields,
is_type_of=is_type_of,
resolve_type=resolve_type,
)
# TODO: remove when deprecating _type_definition
DeprecatedDescriptor(
Expand Down
26 changes: 22 additions & 4 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generic,
Expand All @@ -20,6 +21,7 @@
from typing_extensions import Protocol

from graphql import (
GraphQLAbstractType,
GraphQLArgument,
GraphQLDirective,
GraphQLEnumType,
Expand All @@ -33,6 +35,8 @@
GraphQLObjectType,
GraphQLUnionType,
Undefined,
ValueNode,
default_type_resolver,
)
from graphql.language.directive_locations import DirectiveLocation

Expand Down Expand Up @@ -64,9 +68,7 @@
from strawberry.unset import UNSET
from strawberry.utils.await_maybe import await_maybe

from ..extensions.field_extension import (
build_field_extension_resolvers,
)
from ..extensions.field_extension import build_field_extension_resolvers
from . import compat
from .types.concrete_type import ConcreteType

Expand All @@ -77,7 +79,6 @@
GraphQLOutputType,
GraphQLResolveInfo,
GraphQLScalarType,
ValueNode,
)

from strawberry.custom_scalar import ScalarDefinition
Expand Down Expand Up @@ -426,6 +427,22 @@ def from_interface(
assert isinstance(graphql_interface, GraphQLInterfaceType) # For mypy
return graphql_interface

def _get_resolve_type():
if interface.resolve_type:
return interface.resolve_type

def resolve_type(
obj: Any, info: GraphQLResolveInfo, abstract_type: GraphQLAbstractType
) -> Union[Awaitable[Optional[str]], str, None]:
if isinstance(obj, interface.origin):
return obj.__strawberry_definition__.name
else:
# Revert to calling is_type_of for cases where a direct subclass
# of the interface is not returned (i.e. an ORM object)
return default_type_resolver(obj, info, abstract_type)

return resolve_type

graphql_interface = GraphQLInterfaceType(
name=interface_name,
fields=lambda: self.get_graphql_fields(interface),
Expand All @@ -434,6 +451,7 @@ def from_interface(
extensions={
GraphQLCoreConverter.DEFINITION_BACKREF: interface,
},
resolve_type=_get_resolve_type(),
)

self.type_map[interface_name] = ConcreteType(
Expand Down
6 changes: 5 additions & 1 deletion strawberry/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
from graphql import GraphQLAbstractType, GraphQLResolveInfo

from strawberry.field import StrawberryField

Expand All @@ -51,6 +51,9 @@ class StrawberryObjectDefinition(StrawberryType):
extend: bool
directives: Optional[Sequence[object]]
is_type_of: Optional[Callable[[Any, GraphQLResolveInfo], bool]]
resolve_type: Optional[
Callable[[Any, GraphQLResolveInfo, GraphQLAbstractType], str]
]

_fields: List[StrawberryField]

Expand Down Expand Up @@ -97,6 +100,7 @@ def copy_with(
description=self.description,
extend=self.extend,
is_type_of=self.is_type_of,
resolve_type=self.resolve_type,
_fields=fields,
concrete_of=self,
type_var_map=type_var_map,
Expand Down
167 changes: 153 additions & 14 deletions tests/schema/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import List
from typing import Any, List

import pytest
from pytest_mock import MockerFixture

import strawberry
from strawberry.types.types import StrawberryObjectDefinition
Expand Down Expand Up @@ -42,6 +43,7 @@ def assortment(self) -> List[Cheese]:
result = schema.execute_sync(query)

assert not result.errors
assert result.data is not None
assert result.data["assortment"] == [
{"name": "Asiago", "province": "Friuli"},
{"canton": "Vaud", "name": "Tomme"},
Expand Down Expand Up @@ -92,6 +94,7 @@ def always_error(self) -> Error:
result = schema.execute_sync(query)

assert not result.errors
assert result.data is not None
assert result.data["alwaysError"] == {
"message": "Password Too Short",
"field": "Password",
Expand All @@ -109,7 +112,7 @@ class Anime(Entity):
name: str

@classmethod
def is_type_of(cls, obj, _info) -> bool:
def is_type_of(cls, obj: Any, _) -> bool:
return isinstance(obj, AnimeORM)

@dataclass
Expand All @@ -120,19 +123,19 @@ class AnimeORM:
@strawberry.type
class Query:
@strawberry.field
def anime(self) -> Anime:
def anime(self) -> Entity:
return AnimeORM(id=1, name="One Piece") # type: ignore

schema = strawberry.Schema(query=Query)
schema = strawberry.Schema(query=Query, types=[Anime])

query = """{
anime { name }
anime { id ... on Anime { name } }
}"""

result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"anime": {"name": "One Piece"}}
assert result.data == {"anime": {"id": 1, "name": "One Piece"}}


def test_interface_explicit_type_resolution():
Expand All @@ -150,7 +153,7 @@ class Anime(Node):
name: str

@classmethod
def is_type_of(cls, obj, _info) -> bool:
def is_type_of(cls, obj: Any, _) -> bool:
return isinstance(obj, AnimeORM)

@strawberry.type
Expand All @@ -161,11 +164,17 @@ def node(self) -> Node:

schema = strawberry.Schema(query=Query, types=[Anime])

query = "{ node { __typename, id } }"
query = "{ node { __typename, id ... on Anime { name }} }"
result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"node": {"__typename": "Anime", "id": 1}}
assert result.data == {
"node": {
"__typename": "Anime",
"id": 1,
"name": "One Piece",
}
}


@pytest.mark.xfail(reason="We don't support returning dictionaries yet")
Expand All @@ -178,11 +187,6 @@ class Entity:
class Anime(Entity):
name: str

@dataclass
class AnimeORM:
id: int
name: str

@strawberry.type
class Query:
@strawberry.field
Expand Down Expand Up @@ -239,3 +243,138 @@ class Query:
assert origins == [InterfaceA, InterfaceB, Base]

strawberry.Schema(Query) # Final sanity check to ensure schema compiles


def test_interface_resolve_type(mocker: MockerFixture):
"""Check that the default implemenetation of `resolve_type` functions as expected.
In this test-case the default implementation of `resolve_type` defined in
`GraphQLCoreConverter.from_interface`, should immediately resolve the type of the
returned concrete object. A concrete object is defined as one that is an instance of
the interface it implements.
Before the default implementation of `resolve_type`, the `is_type_of` methods of all
specializations of an interface (in this case Anime & Movie) would be called. As
this needlessly reduces performance, this test checks if only `Anime.is_type_of` is
called when `Query.node` returns an `Anime` object.
"""

class IsTypeOfTester:
@classmethod
def is_type_of(cls, obj: Any, _) -> bool:
return isinstance(obj, cls)

spy_is_type_of = mocker.spy(IsTypeOfTester, "is_type_of")

@strawberry.interface
class Node:
id: int

@strawberry.type
class Anime(Node, IsTypeOfTester):
name: str

@strawberry.type
class Movie(Node):
title: str

@classmethod
def is_type_of(cls, *args: Any, **kwargs: Any) -> bool:
del args, kwargs
raise RuntimeError("Movie.is_type_of shouldn't have been called")

@strawberry.type
class Query:
@strawberry.field
def node(self) -> Node:
return Anime(id=1, name="One Pierce")

schema = strawberry.Schema(query=Query, types=[Anime, Movie])

query = "{ node { __typename, id } }"
result = schema.execute_sync(query)

assert not result.errors
assert result.data == {"node": {"__typename": "Anime", "id": 1}}
spy_is_type_of.assert_called_once()


def test_interface_specialized_resolve_type(mocker: MockerFixture):
"""Test that a specialized ``resolve_type`` is called."""

class InterfaceTester:
@classmethod
def resolve_type(cls, obj: Any, *args: Any, **kwargs: Any) -> str:
del args, kwargs
return obj._type_definition.name

spy_resolve_type = mocker.spy(InterfaceTester, "resolve_type")

@strawberry.interface
class Food(InterfaceTester):
id: int

@strawberry.type
class Fruit(Food):
name: str

@strawberry.type
class Query:
@strawberry.field
def food(self) -> Food:
return Fruit(id=1, name="strawberry")

schema = strawberry.Schema(query=Query, types=[Fruit])
result = schema.execute_sync("query { food { ... on Fruit { name } } }")

assert not result.errors
assert result.data == {"food": {"name": "strawberry"}}
spy_resolve_type.assert_called_once()


@pytest.mark.asyncio
async def test_derived_interface(mocker: MockerFixture):
"""Test if correct resolve_type is called on a derived interface."""

class NodeInterfaceTester:
@classmethod
def resolve_type(cls, obj: Any, *args: Any, **kwargs: Any) -> str:
del args, kwargs
return obj._type_definition.name

class NamedNodeInterfaceTester:
@classmethod
def resolve_type(cls, obj: Any, *args: Any, **kwargs: Any) -> str:
del args, kwargs
return obj._type_definition.name

spy_node_resolve_type = mocker.spy(NodeInterfaceTester, "resolve_type")
spy_named_node_resolve_type = mocker.spy(NamedNodeInterfaceTester, "resolve_type")

@strawberry.interface
class Node(NodeInterfaceTester):
id: int

@strawberry.interface
class NamedNode(NamedNodeInterfaceTester, Node):
name: str

@strawberry.type
class Person(NamedNode):
pass

@strawberry.type
class Query:
@strawberry.field
def friends(self) -> List[NamedNode]:
return [Person(id=1, name="foo"), Person(id=2, name="bar")]

schema = strawberry.Schema(Query, types=[Person])
result = await schema.execute("query { friends { name } }")

assert not result.errors
assert result.data == {"friends": [{"name": "foo"}, {"name": "bar"}]}

assert result.data is not None
assert spy_named_node_resolve_type.call_count == len(result.data["friends"])
spy_node_resolve_type.assert_not_called()

0 comments on commit 92e1151

Please sign in to comment.