Skip to content

Commit

Permalink
Initial support for batching
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Apr 6, 2022
1 parent c64a9f9 commit 13949fd
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 24 deletions.
42 changes: 30 additions & 12 deletions strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from strawberry.file_uploads.utils import replace_placeholders_with_files
from strawberry.http import (
GraphQLHTTPResponse,
GraphQLRequest,
GraphQLRequestData,
parse_request_data,
process_result,
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
self.subscriptions_enabled = subscriptions_enabled
super().__init__(**kwargs)

def parse_body(self, request) -> Dict[str, Any]:
def parse_body(self, request) -> GraphQLRequestData:
if request.content_type.startswith("multipart/form-data"):
data = json.loads(request.POST.get("operations", "{}"))
files_map = json.loads(request.POST.get("map", "{}"))
Expand All @@ -73,7 +74,7 @@ def is_request_allowed(self, request: HttpRequest) -> bool:
def should_render_graphiql(self, request: HttpRequest) -> bool:
return "text/html" in request.headers.get("Accept", "")

def get_request_data(self, request: HttpRequest) -> GraphQLRequestData:
def get_request_data(self, request: HttpRequest) -> GraphQLRequest:
try:
data = self.parse_body(request)
except json.decoder.JSONDecodeError:
Expand Down Expand Up @@ -118,6 +119,7 @@ def _create_response(
response_data,
encoder=self.json_encoder,
json_dumps_params=self.json_dumps_params,
safe=False,
)

for name, value in sub_response.items():
Expand Down Expand Up @@ -146,6 +148,8 @@ def process_result(

@method_decorator(csrf_exempt)
def dispatch(self, request, *args, **kwargs):
assert self.schema

if not self.is_request_allowed(request):
return HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
Expand All @@ -159,17 +163,31 @@ def dispatch(self, request, *args, **kwargs):
sub_response = TemporalHttpResponse()
context = self.get_context(request, response=sub_response)

assert self.schema

result = self.schema.execute_sync(
request_data.query,
root_value=self.get_root_value(request),
variable_values=request_data.variables,
context_value=context,
operation_name=request_data.operation_name,
)
if isinstance(request_data, list):
response_data = [
self.process_result(
request=request,
result=self.schema.execute_sync(
data.query,
root_value=self.get_root_value(request),
variable_values=data.variables,
context_value=context,
operation_name=data.operation_name,
),
)
for data in request_data
]
else:

result = self.schema.execute_sync(
request_data.query,
root_value=self.get_root_value(request),
variable_values=request_data.variables,
context_value=context,
operation_name=request_data.operation_name,
)

response_data = self.process_result(request=request, result=result)
response_data = self.process_result(request=request, result=result)

return self._create_response(
response_data=response_data, sub_response=sub_response
Expand Down
37 changes: 25 additions & 12 deletions strawberry/http.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict

from graphql.error.graphql_error import format_error as format_graphql_error

from strawberry.exceptions import MissingQueryError
from strawberry.types import ExecutionResult


class GraphQLRequestData(TypedDict, total=False):
query: Required[str]
variables: Optional[Dict[str, Any]]
operation_name: Optional[str]


class GraphQLHTTPResponse(TypedDict, total=False):
data: Optional[Dict[str, Any]]
errors: Optional[List[Any]]
Expand All @@ -27,20 +33,27 @@ def process_result(result: ExecutionResult) -> GraphQLHTTPResponse:


@dataclass
class GraphQLRequestData:
class GraphQLRequest:
query: str
variables: Optional[Dict[str, Any]]
operation_name: Optional[str]

@classmethod
def from_dict(cls, data: GraphQLRequestData) -> "GraphQLRequest":
if "query" not in data:
raise MissingQueryError()

return GraphQLRequest(
query=data["query"],
variables=data.get("variables"),
operation_name=data.get("operation_name"),
)

def parse_request_data(data: Dict) -> GraphQLRequestData:
if "query" not in data:
raise MissingQueryError()

result = GraphQLRequestData(
query=data["query"],
variables=data.get("variables"),
operation_name=data.get("operationName"),
)
def parse_request_data(
data: Union[GraphQLRequestData, List[GraphQLRequestData]],
) -> Union[GraphQLRequest, List[GraphQLRequest]]:
if isinstance(data, list):
return [GraphQLRequest.from_dict(d) for d in data]

return result
return GraphQLRequest.from_dict(data)
21 changes: 21 additions & 0 deletions tests/django/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,24 @@ class CustomGraphQLView(GraphQLView):

response2 = CustomGraphQLView.as_view(schema=schema)(request)
assert response1.content == response2.content


def test_supports_batch_queries():
query = "{ hello }"

factory = RequestFactory()
request = factory.post(
"/graphql/",
[{"query": query}, {"query": query}],
content_type="application/json",
)

response = GraphQLView.as_view(schema=schema)(request)

data = json.loads(response.content.decode())

assert response.status_code == 200
assert data == [
{"data": {"hello": "strawberry"}},
{"data": {"hello": "strawberry"}},
]

0 comments on commit 13949fd

Please sign in to comment.