From 13949fd6353cfc9c6a1a94017cc0de6c1dfda7e9 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 5 Apr 2022 20:38:57 -0500 Subject: [PATCH] Initial support for batching --- strawberry/django/views.py | 42 +++++++++++++++++++++++++++----------- strawberry/http.py | 37 ++++++++++++++++++++++----------- tests/django/test_views.py | 21 +++++++++++++++++++ 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 04886a7606..cdd71e6743 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -20,6 +20,7 @@ from strawberry.file_uploads.utils import replace_placeholders_with_files from strawberry.http import ( GraphQLHTTPResponse, + GraphQLRequest, GraphQLRequestData, parse_request_data, process_result, @@ -56,7 +57,7 @@ def __init__( self.subscriptions_enabled = subscriptions_enabled super().__init__(**kwargs) - def parse_body(self, request) -> Dict[str, Any]: + def parse_body(self, request) -> GraphQLRequestData: if request.content_type.startswith("multipart/form-data"): data = json.loads(request.POST.get("operations", "{}")) files_map = json.loads(request.POST.get("map", "{}")) @@ -73,7 +74,7 @@ def is_request_allowed(self, request: HttpRequest) -> bool: def should_render_graphiql(self, request: HttpRequest) -> bool: return "text/html" in request.headers.get("Accept", "") - def get_request_data(self, request: HttpRequest) -> GraphQLRequestData: + def get_request_data(self, request: HttpRequest) -> GraphQLRequest: try: data = self.parse_body(request) except json.decoder.JSONDecodeError: @@ -118,6 +119,7 @@ def _create_response( response_data, encoder=self.json_encoder, json_dumps_params=self.json_dumps_params, + safe=False, ) for name, value in sub_response.items(): @@ -146,6 +148,8 @@ def process_result( @method_decorator(csrf_exempt) def dispatch(self, request, *args, **kwargs): + assert self.schema + if not self.is_request_allowed(request): return HttpResponseNotAllowed( ["GET", "POST"], "GraphQL only supports GET and POST requests." @@ -159,17 +163,31 @@ def dispatch(self, request, *args, **kwargs): sub_response = TemporalHttpResponse() context = self.get_context(request, response=sub_response) - assert self.schema - - result = self.schema.execute_sync( - request_data.query, - root_value=self.get_root_value(request), - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - ) + if isinstance(request_data, list): + response_data = [ + self.process_result( + request=request, + result=self.schema.execute_sync( + data.query, + root_value=self.get_root_value(request), + variable_values=data.variables, + context_value=context, + operation_name=data.operation_name, + ), + ) + for data in request_data + ] + else: + + result = self.schema.execute_sync( + request_data.query, + root_value=self.get_root_value(request), + variable_values=request_data.variables, + context_value=context, + operation_name=request_data.operation_name, + ) - response_data = self.process_result(request=request, result=result) + response_data = self.process_result(request=request, result=result) return self._create_response( response_data=response_data, sub_response=sub_response diff --git a/strawberry/http.py b/strawberry/http.py index aa4086c434..b55fde008f 100644 --- a/strawberry/http.py +++ b/strawberry/http.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict from graphql.error.graphql_error import format_error as format_graphql_error @@ -9,6 +9,12 @@ from strawberry.types import ExecutionResult +class GraphQLRequestData(TypedDict, total=False): + query: Required[str] + variables: Optional[Dict[str, Any]] + operation_name: Optional[str] + + class GraphQLHTTPResponse(TypedDict, total=False): data: Optional[Dict[str, Any]] errors: Optional[List[Any]] @@ -27,20 +33,27 @@ def process_result(result: ExecutionResult) -> GraphQLHTTPResponse: @dataclass -class GraphQLRequestData: +class GraphQLRequest: query: str variables: Optional[Dict[str, Any]] operation_name: Optional[str] + @classmethod + def from_dict(cls, data: GraphQLRequestData) -> "GraphQLRequest": + if "query" not in data: + raise MissingQueryError() + + return GraphQLRequest( + query=data["query"], + variables=data.get("variables"), + operation_name=data.get("operation_name"), + ) -def parse_request_data(data: Dict) -> GraphQLRequestData: - if "query" not in data: - raise MissingQueryError() - result = GraphQLRequestData( - query=data["query"], - variables=data.get("variables"), - operation_name=data.get("operationName"), - ) +def parse_request_data( + data: Union[GraphQLRequestData, List[GraphQLRequestData]], +) -> Union[GraphQLRequest, List[GraphQLRequest]]: + if isinstance(data, list): + return [GraphQLRequest.from_dict(d) for d in data] - return result + return GraphQLRequest.from_dict(data) diff --git a/tests/django/test_views.py b/tests/django/test_views.py index 1e4bfe3595..f7c6d4f05d 100644 --- a/tests/django/test_views.py +++ b/tests/django/test_views.py @@ -374,3 +374,24 @@ class CustomGraphQLView(GraphQLView): response2 = CustomGraphQLView.as_view(schema=schema)(request) assert response1.content == response2.content + + +def test_supports_batch_queries(): + query = "{ hello }" + + factory = RequestFactory() + request = factory.post( + "/graphql/", + [{"query": query}, {"query": query}], + content_type="application/json", + ) + + response = GraphQLView.as_view(schema=schema)(request) + + data = json.loads(response.content.decode()) + + assert response.status_code == 200 + assert data == [ + {"data": {"hello": "strawberry"}}, + {"data": {"hello": "strawberry"}}, + ]