Skip to content

Commit

Permalink
Codegen generic types (#3077)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Gilson <mgilson@lat.ai>
  • Loading branch information
mgilson and Matt Gilson authored Sep 15, 2023
1 parent 47578dd commit ecba2c3
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 3 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This fixes a bug where codegen would choke trying to find a field in the schema for a generic type.
24 changes: 22 additions & 2 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,32 @@ def _field_from_selection_set(
) -> GraphQLField:
assert selection.selection_set is not None

parent_type_name = parent_type.name

# Check if the parent type is generic.
# This seems to be tracked by `strawberry` in the `type_var_map`
# If the type is generic, then the strawberry generated schema
# naming convention is <GenericType,...><ClassName>
# The implementation here assumes that the `type_var_map` is ordered,
# but insertion order is maintained in python3.6+ (for CPython) and
# guaranteed for all python implementations in python3.7+, so that
# should be pretty safe.
if parent_type.type_var_map:
parent_type_name = (
"".join(
c.__name__ # type: ignore[union-attr]
for c in parent_type.type_var_map.values()
)
+ parent_type.name
)

selected_field = self.schema.get_field_for_type(
selection.name.value, parent_type.name
selection.name.value, parent_type_name
)

assert (
selected_field
), f"Couldn't find {parent_type.name}.{selection.name.value}"
), f"Couldn't find {parent_type_name}.{selection.name.value}"

selected_field_type, wrapper = self._unwrap_type(selected_field.type)
name = capitalize_first(to_camel_case(selection.name.value))
Expand Down
19 changes: 18 additions & 1 deletion tests/codegen/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import decimal
import enum
import random
from typing import TYPE_CHECKING, List, NewType, Optional, Union
from typing import TYPE_CHECKING, Generic, List, NewType, Optional, TypeVar, Union
from typing_extensions import Annotated
from uuid import UUID

Expand Down Expand Up @@ -35,6 +35,16 @@ class Animal:
age: int


LivingThing1 = TypeVar("LivingThing1")
LivingThing2 = TypeVar("LivingThing2")


@strawberry.type
class LifeContainer(Generic[LivingThing1, LivingThing2]):
items1: List[LivingThing1]
items2: List[LivingThing2]


PersonOrAnimal = Annotated[Union[Person, Animal], strawberry.union("PersonOrAnimal")]


Expand Down Expand Up @@ -110,6 +120,13 @@ def get_person_or_animal(self) -> Union[Person, Animal]:
p_or_a.age = 7
return p_or_a

@strawberry.field
def list_life() -> LifeContainer[Person, Animal]:
"""Get lists of living things."""
person = Person(name="Henry", age=10)
dinosaur = Animal(name="rex", age=66_000_000)
return LifeContainer([person], [dinosaur])


@strawberry.input
class BlogPostInput:
Expand Down
12 changes: 12 additions & 0 deletions tests/codegen/queries/generic_types.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
query ListLifeGeneric {
listLife {
items1 {
name
age
}
items2 {
name
age
}
}
}
16 changes: 16 additions & 0 deletions tests/codegen/snapshots/python/generic_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List

class ListLifeGenericResultListLifeItems1:
name: str
age: int

class ListLifeGenericResultListLifeItems2:
name: str
age: int

class ListLifeGenericResultListLife:
items1: List[ListLifeGenericResultListLifeItems1]
items2: List[ListLifeGenericResultListLifeItems2]

class ListLifeGenericResult:
list_life: ListLifeGenericResultListLife
18 changes: 18 additions & 0 deletions tests/codegen/snapshots/typescript/generic_types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
type ListLifeGenericResultListLifeItems1 = {
name: string
age: number
}

type ListLifeGenericResultListLifeItems2 = {
name: string
age: number
}

type ListLifeGenericResultListLife = {
items1: ListLifeGenericResultListLifeItems1[]
items2: ListLifeGenericResultListLifeItems2[]
}

type ListLifeGenericResult = {
list_life: ListLifeGenericResultListLife
}

0 comments on commit ecba2c3

Please sign in to comment.