Skip to content

Commit

Permalink
Add schema codegen command (#3096)
Browse files Browse the repository at this point in the history
* Initial POC for codegeneration

* Add support for descriptions

* Rename test file

* Add support for interfaces

* Add support for multi-line descriptions

* Add support for snake casing names

* For the future

* Initial cli, add support for explicit schema

* Add support for enums

* Add todo

* Add support for unions

* Add support for inputs

* Add support for extensions

* Add tests for extension

* fix snake case

* Add support for scalars

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Mypy fixes

* Fix backwards compatibility

* Add support for generating the schema object

* Add release file

* Fix tests

* Improve coverage

* Add basic docs

* Add link

* Add CLI tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
patrick91 and pre-commit-ci[bot] authored Sep 19, 2023
1 parent 3f219a1 commit 70357c0
Show file tree
Hide file tree
Showing 30 changed files with 1,355 additions and 20 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ repos:
- id: end-of-file-fixer
exclude: ^tests/relay/snapshots
- id: check-toml
- id: no-commit-to-branch
args: ['--branch', 'main']

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.15.0
Expand Down
40 changes: 40 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Release type: minor

This release adds support for generating Strawberry types from SDL files. For example, given the following SDL file:

```graphql
type Query {
user: User
}

type User {
id: ID!
name: String!
}
```

you can run

```bash
strawberry schema-codegen schema.graphql
```

to generate the following Python code:

```python
import strawberry


@strawberry.type
class Query:
user: User | None


@strawberry.type
class User:
id: strawberry.ID
name: str


schema = strawberry.Schema(query=Query)
```
1 change: 1 addition & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

## Codegen

- [Schema codegen](./codegen/schema-codegen.md)
- [Query codegen](./codegen/query-codegen.md)

## [Extensions](./extensions)
Expand Down
46 changes: 46 additions & 0 deletions docs/codegen/schema-codegen.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
---
title: Schema codegen
---

# Schema codegen

Strawberry supports code generation from SDL files.

Let's assume we have the following SDL file:

```graphql
type Query {
user: User
}

type User {
id: ID!
name: String!
}
```

by running the following command:

```shell
strawberry schema-codegen schema.graphql
```

we'll get the following output:

```python
import strawberry


@strawberry.type
class Query:
user: User | None


@strawberry.type
class User:
id: strawberry.ID
name: str


schema = strawberry.Schema(query=Query)
```
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ ignore = [
"PT001",
"PT023",

# this is pretty much handled by black
"E501",

# enable these, we have some in tests
"B006",
"PT004",
Expand Down Expand Up @@ -347,17 +350,9 @@ src = ["strawberry", "tests"]

[tool.ruff.per-file-ignores]
"strawberry/schema/types/concrete_type.py" = ["TCH002"]
"tests/federation/printer/*" = ["E501"]
"tests/test_printer/test_basic.py" = ["E501"]
"tests/pyright/test_federation.py" = ["E501"]
"tests/codemods/*" = ["E501"]
"tests/test_printer/test_schema_directives.py" = ["E501"]
"tests/*" = ["RSE102", "SLF001", "TCH001", "TCH002", "TCH003", "ANN001", "ANN201", "PLW0603", "PLC1901", "S603", "S607", "B018"]
"strawberry/extensions/tracing/__init__.py" = ["TCH004"]
"tests/http/clients/__init__.py" = ["F401"]
"tests/schema/test_scalars.py" = ["E501"]
"tests/schema/test_basic.py" = ["E501"]
"tests/schema/test_union.py" = ["E501"]

[tool.ruff.isort]
known-first-party = ["strawberry"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self._ws = web.WebSocketResponse(protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])

async def get_context(self) -> Any:
return await self._get_context(request=self._request, response=self._ws) # type: ignore # noqa: E501
return await self._get_context(request=self._request, response=self._ws) # type: ignore

async def get_root_value(self) -> Any:
return await self._get_root_value(request=self._request)
Expand Down
4 changes: 2 additions & 2 deletions strawberry/channels/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def channel_listen(
if self.channel_layer is None:
raise RuntimeError(
"Layers integration is required listening for channels.\n"
"Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html " # noqa:E501
"Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html "
"for more information"
)

Expand Down Expand Up @@ -176,7 +176,7 @@ async def listen_to_channel(
if self.channel_layer is None:
raise RuntimeError(
"Layers integration is required listening for channels.\n"
"Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html " # noqa:E501
"Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html "
"for more information"
)

Expand Down
1 change: 1 addition & 0 deletions strawberry/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .commands.export_schema import export_schema as export_schema # noqa
from .commands.server import server as server # noqa
from .commands.upgrade import upgrade as upgrade # noqa
from .commands.schema_codegen import schema_codegen as schema_codegen # noqa

from .app import app

Expand Down
32 changes: 32 additions & 0 deletions strawberry/cli/commands/schema_codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional

import typer

from strawberry.cli.app import app
from strawberry.schema_codegen import codegen


@app.command(help="Generate code from a query")
def schema_codegen(
schema: Path = typer.Argument(exists=True),
output: Optional[Path] = typer.Option(
None,
"-o",
"--output",
file_okay=True,
dir_okay=False,
writable=True,
resolve_path=True,
),
) -> None:
generated_output = codegen(schema.read_text())

if output is None:
typer.echo(generated_output)
return

output.parent.mkdir(parents=True, exist_ok=True)
output.write_text(generated_output)

typer.echo(f"Code generated at `{output.name}`")
4 changes: 2 additions & 2 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,14 @@ def _get_field_type(
not isinstance(field_type, StrawberryType)
and field_type in self.schema.schema_converter.scalar_registry
):
field_type = self.schema.schema_converter.scalar_registry[field_type] # type: ignore # noqa: E501
field_type = self.schema.schema_converter.scalar_registry[field_type] # type: ignore

if isinstance(field_type, ScalarWrapper):
python_type = field_type.wrap
if hasattr(python_type, "__supertype__"):
python_type = python_type.__supertype__

return self._collect_scalar(field_type._scalar_definition, python_type) # type: ignore # noqa: E501
return self._collect_scalar(field_type._scalar_definition, python_type) # type: ignore

if isinstance(field_type, ScalarDefinition):
return self._collect_scalar(field_type, None)
Expand Down
2 changes: 1 addition & 1 deletion strawberry/codemods/annotated_unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def leave_union_call(
union_node = cst.Subscript(
value=cst.Name(value="Union"),
slice=[
cst.SubscriptElement(slice=cst.Index(value=t.value)) for t in types # type: ignore # noqa: E501
cst.SubscriptElement(slice=cst.Index(value=t.value)) for t in types # type: ignore
],
)

Expand Down
2 changes: 1 addition & 1 deletion strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]:
# https://docs.djangoproject.com/en/3.1/topics/async/#async-views

view = super().as_view(**initkwargs)
view._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined] # noqa: E501
view._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]
return view

async def get_root_value(self, request: HttpRequest) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion strawberry/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _python_name(self) -> Optional[str]:
def _set_python_name(self, name: str) -> None:
self.name = name

python_name: str = property(_python_name, _set_python_name) # type: ignore[assignment] # noqa: E501
python_name: str = property(_python_name, _set_python_name) # type: ignore[assignment]

@property
def base_resolver(self) -> Optional[StrawberryResolver]:
Expand Down
Loading

0 comments on commit 70357c0

Please sign in to comment.