Skip to content

Commit

Permalink
Fix incorrect query params type (#3558)
Browse files Browse the repository at this point in the history
* Fix assumption query params could be lists

* Add release file

* Test that variables are still successfully used

* Lint and format
  • Loading branch information
DoctorJohn authored Jul 8, 2024
1 parent d2c0fb4 commit c25da89
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 18 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release removes an unnecessary check from our internal GET query parsing logic making it simpler and (insignificantly) faster.
4 changes: 2 additions & 2 deletions strawberry/flask/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union, cast

from flask import Request, Response, render_template_string, request
from flask.views import View
Expand All @@ -26,7 +26,7 @@ def __init__(self, request: Request) -> None:
self.request = request

@property
def query_params(self) -> Mapping[str, Union[str, Optional[List[str]]]]:
def query_params(self) -> QueryParams:
return self.request.args.to_dict()

@property
Expand Down
9 changes: 2 additions & 7 deletions strawberry/http/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE, get_graphql_ide_html
from strawberry.http.types import HTTPMethod
from strawberry.http.types import HTTPMethod, QueryParams

from .exceptions import HTTPException
from .typevars import Request
Expand Down Expand Up @@ -50,17 +50,12 @@ def parse_json(self, data: Union[str, bytes]) -> Any:
def encode_json(self, response_data: GraphQLHTTPResponse) -> str:
return json.dumps(response_data)

def parse_query_params(
self, params: Mapping[str, Optional[Union[str, List[str]]]]
) -> Dict[str, Any]:
def parse_query_params(self, params: QueryParams) -> Dict[str, Any]:
params = dict(params)

if "variables" in params:
variables = params["variables"]

if isinstance(variables, list):
variables = variables[0]

if variables:
params["variables"] = self.parse_json(variables)

Expand Down
4 changes: 2 additions & 2 deletions strawberry/http/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, List, Mapping, Optional, Union
from typing import Any, Mapping, Optional
from typing_extensions import Literal, TypedDict

HTTPMethod = Literal[
"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "TRACE"
]

QueryParams = Mapping[str, Optional[Union[str, List[str]]]]
QueryParams = Mapping[str, Optional[str]]


class FormData(TypedDict):
Expand Down
8 changes: 1 addition & 7 deletions strawberry/sanic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
Optional,
Type,
Expand Down Expand Up @@ -43,12 +42,7 @@ def query_params(self) -> QueryParams:
# the keys are the unique variable names and the values are lists
# of values for each variable name. To ensure consistency, we're
# enforcing the use of the first value in each list.

args = cast(
Dict[str, Optional[List[str]]],
self.request.get_args(keep_blank_values=True),
)

args = self.request.get_args(keep_blank_values=True)
return {k: args.get(k, None) for k in args}

@property
Expand Down
9 changes: 9 additions & 0 deletions tests/http/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ async def test_passing_invalid_json_get(http_client: HttpClient):
assert "Unable to parse request body as JSON" in response.text


async def test_query_parameters_are_never_interpreted_as_list(http_client: HttpClient):
response = await http_client.get(
url='/graphql?query=query($name: String!) { hello(name: $name) }&variables={"name": "Jake"}&variables={"name": "Jake"}',
)

assert response.status_code == 200
assert response.json["data"] == {"hello": "Hello Jake"}


async def test_missing_query(http_client: HttpClient):
response = await http_client.post(
url="/graphql",
Expand Down

0 comments on commit c25da89

Please sign in to comment.