Skip to content

Commit

Permalink
Rename add_graphql.py to autoregistration.py
Browse files Browse the repository at this point in the history
- check if model is already added to strawberry_models.
- rename `EnumList` to `EnumDict`.
- update fail check in unit tests.
  • Loading branch information
tjeerddie committed Jul 11, 2023
1 parent 53f1789 commit 726d787
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 31 deletions.
4 changes: 2 additions & 2 deletions orchestrator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from orchestrator.forms import FormError
from orchestrator.graphql import (
GRAPHQL_MODELS,
EnumList,
EnumDict,
Mutation,
Query,
add_class_to_strawberry,
Expand Down Expand Up @@ -189,7 +189,7 @@ def register_subscription_models(product_to_subscription_model_mapping: Dict[str

def register_graphql(self: "OrchestratorCore", query: Any = Query, mutation: Any = Mutation) -> None:
strawberry_models = GRAPHQL_MODELS
strawberry_enums: EnumList = {}
strawberry_enums: EnumDict = {}
products = {
product_type.__base_type__.__name__: product_type.__base_type__
for product_type in SUBSCRIPTION_MODEL_REGISTRY.values()
Expand Down
6 changes: 3 additions & 3 deletions orchestrator/graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from orchestrator.graphql.add_graphql import (
EnumList,
from orchestrator.graphql.autoregistration import (
EnumDict,
add_class_to_strawberry,
graphql_subscription_name,
)
Expand All @@ -37,7 +37,7 @@
"get_context",
"graphql_router",
"create_graphql_router",
"EnumList",
"EnumDict",
"add_class_to_strawberry",
"graphql_subscription_name",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
from typing import Any, Type

import strawberry
import structlog
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic

from orchestrator.domain.base import DomainModel, get_depends_on_product_block_type_list
from orchestrator.graphql.schema import StrawberryModelType
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
from orchestrator.utils.helpers import to_camel

EnumList = dict[str, EnumMeta]
logger = structlog.get_logger(__name__)

EnumDict = dict[str, EnumMeta]


def create_strawberry_enum(enum: Any) -> EnumMeta:
return strawberry.enum(enum)


def is_not_strawberry_enum(key: str, strawberry_enums: EnumList) -> bool:
def is_not_strawberry_enum(key: str, strawberry_enums: EnumDict) -> bool:
return key not in strawberry_enums


Expand Down Expand Up @@ -75,7 +78,7 @@ def create_block_strawberry_type(
return strawberry_wrapper(new_type)


def create_strawberry_enums(model: Type[DomainModel], strawberry_enums: EnumList) -> EnumList:
def create_strawberry_enums(model: Type[DomainModel], strawberry_enums: EnumDict) -> EnumDict:
enums = {
key: field
for key, field in model._non_product_block_fields_.items()
Expand All @@ -88,18 +91,23 @@ def add_class_to_strawberry(
model_name: str,
model: Type[DomainModel],
strawberry_models: StrawberryModelType,
strawberry_enums: EnumList,
strawberry_enums: EnumDict,
with_interface: bool = False,
) -> None:
if model_name in strawberry_models:
logger.debug("Skip already registered strawberry model", model=repr(model), strawberry_name=model_name)
return
logger.debug("Registering strawberry model", model=repr(model), strawberry_name=model_name)

strawberry_enums = create_strawberry_enums(model, strawberry_enums)

product_blocks_types_in_model = get_depends_on_product_block_type_list(model._get_depends_on_product_block_types())
for field in product_blocks_types_in_model:
if is_not_strawberry_type(field.__name__, strawberry_models) and field.__name__ != model_name:
add_class_to_strawberry(field.__name__, field, strawberry_models, strawberry_enums)
graphql_field_name = graphql_name(field.__name__)
if is_not_strawberry_type(graphql_field_name, strawberry_models) and graphql_field_name != model_name:
add_class_to_strawberry(graphql_field_name, field, strawberry_models, strawberry_enums)

strawberry_name = graphql_name(model_name)
strawberry_type_convert_function = (
create_subscription_strawberry_type if with_interface else create_block_strawberry_type
)
strawberry_models[strawberry_name] = strawberry_type_convert_function(strawberry_name, model)
strawberry_models[model_name] = strawberry_type_convert_function(model_name, model)
2 changes: 1 addition & 1 deletion orchestrator/graphql/resolvers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def has_details(selection: Selection) -> bool:


def get_subscription_details(subscription: SubscriptionTable) -> SubscriptionInterface:
from orchestrator.graphql.add_graphql import graphql_subscription_name
from orchestrator.graphql.autoregistration import graphql_subscription_name
from orchestrator.graphql.schema import GRAPHQL_MODELS

subscription_details = SubscriptionModel.from_subscription(subscription.subscription_id)
Expand Down
30 changes: 13 additions & 17 deletions test/unit_tests/graphql/test_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,20 +889,18 @@ def test_single_subscription_with_in_use_by_subscriptions(
assert result_in_use_by_ids == expected_in_use_by_ids


def expect_fail_test_if_too_many_duplicate_types_in_interface():
from orchestrator.graphql.schemas.subscription import SubscriptionInterface

product_one_sub_classes = [cl for cl in SubscriptionInterface.__subclasses__() if "ProductOne" in cl.__name__]
if len(product_one_sub_classes) > 1:
pytest.xfail("Test breaks when graphql interface has duplicate graphql types as subtype")
def expect_fail_test_if_too_many_duplicate_types_in_interface(result):
if "errors" in result and "but got: <ProductBlockModelGraphql instance>." in result["errors"][0]["message"]:
pytest.xfail(
"Test fails with all tests because classes are re-created every test and are not recognized as the same class"
)


def test_subscriptions_product_generic_one(
fastapi_app_graphql,
test_client,
product_type_1_subscriptions_factory,
):
expect_fail_test_if_too_many_duplicate_types_in_interface()
# when

subscriptions = product_type_1_subscriptions_factory(30)
Expand All @@ -914,11 +912,12 @@ def test_subscriptions_product_generic_one(

assert HTTPStatus.OK == response.status_code
result = response.json()
expect_fail_test_if_too_many_duplicate_types_in_interface(result)

subscriptions_data = result["data"]["subscriptions"]
subscriptions = subscriptions_data["page"]
pageinfo = subscriptions_data["pageInfo"]

assert "errors" not in result
assert len(subscriptions) == 1
assert pageinfo == {
"hasPreviousPage": False,
Expand All @@ -937,8 +936,6 @@ def test_single_subscription_product_list_union_type(
test_client,
product_sub_list_union_subscription_1,
):
expect_fail_test_if_too_many_duplicate_types_in_interface()

# when

subscription_id = str(product_sub_list_union_subscription_1)
Expand All @@ -949,11 +946,12 @@ def test_single_subscription_product_list_union_type(

assert HTTPStatus.OK == response.status_code
result = response.json()
expect_fail_test_if_too_many_duplicate_types_in_interface(result)

subscriptions_data = result["data"]["subscriptions"]
subscriptions = subscriptions_data["page"]
pageinfo = subscriptions_data["pageInfo"]

assert "errors" not in result
assert len(subscriptions) == 1
assert pageinfo == {
"hasPreviousPage": False,
Expand All @@ -976,8 +974,6 @@ def test_single_subscription_product_list_union_type_provisioning_subscription(
test_client,
product_sub_list_union_subscription_1,
):
expect_fail_test_if_too_many_duplicate_types_in_interface()

# when

subscription = SubscriptionModel.from_subscription(product_sub_list_union_subscription_1)
Expand All @@ -992,11 +988,12 @@ def test_single_subscription_product_list_union_type_provisioning_subscription(

assert HTTPStatus.OK == response.status_code
result = response.json()
expect_fail_test_if_too_many_duplicate_types_in_interface(result)

subscriptions_data = result["data"]["subscriptions"]
subscriptions = subscriptions_data["page"]
pageinfo = subscriptions_data["pageInfo"]

assert "errors" not in result
assert len(subscriptions) == 1
assert pageinfo == {
"hasPreviousPage": False,
Expand All @@ -1019,8 +1016,6 @@ def test_single_subscription_product_list_union_type_terminated_subscription(
test_client,
product_sub_list_union_subscription_1,
):
expect_fail_test_if_too_many_duplicate_types_in_interface()

# when

subscription = SubscriptionModel.from_subscription(product_sub_list_union_subscription_1)
Expand All @@ -1035,11 +1030,12 @@ def test_single_subscription_product_list_union_type_terminated_subscription(

assert HTTPStatus.OK == response.status_code
result = response.json()
expect_fail_test_if_too_many_duplicate_types_in_interface(result)

subscriptions_data = result["data"]["subscriptions"]
subscriptions = subscriptions_data["page"]
pageinfo = subscriptions_data["pageInfo"]

assert "errors" not in result
assert len(subscriptions) == 1
assert pageinfo == {
"hasPreviousPage": False,
Expand Down

0 comments on commit 726d787

Please sign in to comment.