diff --git a/orchestrator/app.py b/orchestrator/app.py index ba44e7a25..ee541daa1 100644 --- a/orchestrator/app.py +++ b/orchestrator/app.py @@ -44,7 +44,7 @@ from orchestrator.forms import FormError from orchestrator.graphql import ( GRAPHQL_MODELS, - EnumList, + EnumDict, Mutation, Query, add_class_to_strawberry, @@ -189,7 +189,7 @@ def register_subscription_models(product_to_subscription_model_mapping: Dict[str def register_graphql(self: "OrchestratorCore", query: Any = Query, mutation: Any = Mutation) -> None: strawberry_models = GRAPHQL_MODELS - strawberry_enums: EnumList = {} + strawberry_enums: EnumDict = {} products = { product_type.__base_type__.__name__: product_type.__base_type__ for product_type in SUBSCRIPTION_MODEL_REGISTRY.values() diff --git a/orchestrator/graphql/__init__.py b/orchestrator/graphql/__init__.py index 521bcbf98..a66ca0e58 100644 --- a/orchestrator/graphql/__init__.py +++ b/orchestrator/graphql/__init__.py @@ -10,8 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from orchestrator.graphql.add_graphql import ( - EnumList, +from orchestrator.graphql.autoregistration import ( + EnumDict, add_class_to_strawberry, graphql_subscription_name, ) @@ -37,7 +37,7 @@ "get_context", "graphql_router", "create_graphql_router", - "EnumList", + "EnumDict", "add_class_to_strawberry", "graphql_subscription_name", ] diff --git a/orchestrator/graphql/add_graphql.py b/orchestrator/graphql/autoregistration.py similarity index 81% rename from orchestrator/graphql/add_graphql.py rename to orchestrator/graphql/autoregistration.py index 2f8b80c1e..ba6efc7a3 100644 --- a/orchestrator/graphql/add_graphql.py +++ b/orchestrator/graphql/autoregistration.py @@ -16,6 +16,7 @@ from typing import Any, Type import strawberry +import structlog from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic from orchestrator.domain.base import DomainModel, get_depends_on_product_block_type_list @@ -23,14 +24,16 @@ from orchestrator.graphql.schemas.subscription import SubscriptionInterface from orchestrator.utils.helpers import to_camel -EnumList = dict[str, EnumMeta] +logger = structlog.get_logger(__name__) + +EnumDict = dict[str, EnumMeta] def create_strawberry_enum(enum: Any) -> EnumMeta: return strawberry.enum(enum) -def is_not_strawberry_enum(key: str, strawberry_enums: EnumList) -> bool: +def is_not_strawberry_enum(key: str, strawberry_enums: EnumDict) -> bool: return key not in strawberry_enums @@ -75,7 +78,7 @@ def create_block_strawberry_type( return strawberry_wrapper(new_type) -def create_strawberry_enums(model: Type[DomainModel], strawberry_enums: EnumList) -> EnumList: +def create_strawberry_enums(model: Type[DomainModel], strawberry_enums: EnumDict) -> EnumDict: enums = { key: field for key, field in model._non_product_block_fields_.items() @@ -88,18 +91,23 @@ def add_class_to_strawberry( model_name: str, model: Type[DomainModel], strawberry_models: StrawberryModelType, - strawberry_enums: EnumList, + strawberry_enums: EnumDict, with_interface: bool = False, ) -> None: + if model_name in strawberry_models: + logger.debug("Skip already registered strawberry model", model=repr(model), strawberry_name=model_name) + return + logger.debug("Registering strawberry model", model=repr(model), strawberry_name=model_name) + strawberry_enums = create_strawberry_enums(model, strawberry_enums) product_blocks_types_in_model = get_depends_on_product_block_type_list(model._get_depends_on_product_block_types()) for field in product_blocks_types_in_model: - if is_not_strawberry_type(field.__name__, strawberry_models) and field.__name__ != model_name: - add_class_to_strawberry(field.__name__, field, strawberry_models, strawberry_enums) + graphql_field_name = graphql_name(field.__name__) + if is_not_strawberry_type(graphql_field_name, strawberry_models) and graphql_field_name != model_name: + add_class_to_strawberry(graphql_field_name, field, strawberry_models, strawberry_enums) - strawberry_name = graphql_name(model_name) strawberry_type_convert_function = ( create_subscription_strawberry_type if with_interface else create_block_strawberry_type ) - strawberry_models[strawberry_name] = strawberry_type_convert_function(strawberry_name, model) + strawberry_models[model_name] = strawberry_type_convert_function(model_name, model) diff --git a/orchestrator/graphql/resolvers/subscription.py b/orchestrator/graphql/resolvers/subscription.py index 27b544514..57aabf084 100644 --- a/orchestrator/graphql/resolvers/subscription.py +++ b/orchestrator/graphql/resolvers/subscription.py @@ -66,7 +66,7 @@ def has_details(selection: Selection) -> bool: def get_subscription_details(subscription: SubscriptionTable) -> SubscriptionInterface: - from orchestrator.graphql.add_graphql import graphql_subscription_name + from orchestrator.graphql.autoregistration import graphql_subscription_name from orchestrator.graphql.schema import GRAPHQL_MODELS subscription_details = SubscriptionModel.from_subscription(subscription.subscription_id) diff --git a/orchestrator/graphql/types.py b/orchestrator/graphql/types.py index 0e6c544dc..9698c7ac1 100644 --- a/orchestrator/graphql/types.py +++ b/orchestrator/graphql/types.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Awaitable from ipaddress import IPv4Address, IPv4Interface, IPv6Address, IPv6Interface # Map some Orchestrator types to scalars @@ -18,6 +19,7 @@ import strawberry from graphql import GraphQLError +from starlette.requests import Request from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper from strawberry.types import Info from strawberry.types.info import RootValueType @@ -38,7 +40,11 @@ def serialize_vlan(vlan: VlanRanges) -> List[Tuple[int, int]]: class OrchestratorContext(OauthContext): - def __init__(self, get_current_user: Callable[[], OIDCUserModel], get_opa_decision: Callable[[str], bool]): + def __init__( + self, + get_current_user: Callable[[Request], Awaitable[OIDCUserModel]], + get_opa_decision: Callable[[str, OIDCUserModel], Awaitable[Union[bool, None]]], + ): self.errors: list[GraphQLError] = [] super().__init__(get_current_user, get_opa_decision) diff --git a/test/unit_tests/graphql/test_subscriptions.py b/test/unit_tests/graphql/test_subscriptions.py index b34431312..bea58d15e 100644 --- a/test/unit_tests/graphql/test_subscriptions.py +++ b/test/unit_tests/graphql/test_subscriptions.py @@ -889,12 +889,11 @@ def test_single_subscription_with_in_use_by_subscriptions( assert result_in_use_by_ids == expected_in_use_by_ids -def expect_fail_test_if_too_many_duplicate_types_in_interface(): - from orchestrator.graphql.schemas.subscription import SubscriptionInterface - - product_one_sub_classes = [cl for cl in SubscriptionInterface.__subclasses__() if "ProductOne" in cl.__name__] - if len(product_one_sub_classes) > 1: - pytest.xfail("Test breaks when graphql interface has duplicate graphql types as subtype") +def expect_fail_test_if_too_many_duplicate_types_in_interface(result): + if "errors" in result and "but got: ." in result["errors"][0]["message"]: + pytest.xfail( + "Test fails with all tests because classes are re-created every test and are not recognized as the same class" + ) def test_subscriptions_product_generic_one( @@ -902,7 +901,6 @@ def test_subscriptions_product_generic_one( test_client, product_type_1_subscriptions_factory, ): - expect_fail_test_if_too_many_duplicate_types_in_interface() # when subscriptions = product_type_1_subscriptions_factory(30) @@ -914,11 +912,12 @@ def test_subscriptions_product_generic_one( assert HTTPStatus.OK == response.status_code result = response.json() + expect_fail_test_if_too_many_duplicate_types_in_interface(result) + subscriptions_data = result["data"]["subscriptions"] subscriptions = subscriptions_data["page"] pageinfo = subscriptions_data["pageInfo"] - assert "errors" not in result assert len(subscriptions) == 1 assert pageinfo == { "hasPreviousPage": False, @@ -937,8 +936,6 @@ def test_single_subscription_product_list_union_type( test_client, product_sub_list_union_subscription_1, ): - expect_fail_test_if_too_many_duplicate_types_in_interface() - # when subscription_id = str(product_sub_list_union_subscription_1) @@ -949,11 +946,12 @@ def test_single_subscription_product_list_union_type( assert HTTPStatus.OK == response.status_code result = response.json() + expect_fail_test_if_too_many_duplicate_types_in_interface(result) + subscriptions_data = result["data"]["subscriptions"] subscriptions = subscriptions_data["page"] pageinfo = subscriptions_data["pageInfo"] - assert "errors" not in result assert len(subscriptions) == 1 assert pageinfo == { "hasPreviousPage": False, @@ -976,8 +974,6 @@ def test_single_subscription_product_list_union_type_provisioning_subscription( test_client, product_sub_list_union_subscription_1, ): - expect_fail_test_if_too_many_duplicate_types_in_interface() - # when subscription = SubscriptionModel.from_subscription(product_sub_list_union_subscription_1) @@ -992,11 +988,12 @@ def test_single_subscription_product_list_union_type_provisioning_subscription( assert HTTPStatus.OK == response.status_code result = response.json() + expect_fail_test_if_too_many_duplicate_types_in_interface(result) + subscriptions_data = result["data"]["subscriptions"] subscriptions = subscriptions_data["page"] pageinfo = subscriptions_data["pageInfo"] - assert "errors" not in result assert len(subscriptions) == 1 assert pageinfo == { "hasPreviousPage": False, @@ -1019,8 +1016,6 @@ def test_single_subscription_product_list_union_type_terminated_subscription( test_client, product_sub_list_union_subscription_1, ): - expect_fail_test_if_too_many_duplicate_types_in_interface() - # when subscription = SubscriptionModel.from_subscription(product_sub_list_union_subscription_1) @@ -1035,11 +1030,12 @@ def test_single_subscription_product_list_union_type_terminated_subscription( assert HTTPStatus.OK == response.status_code result = response.json() + expect_fail_test_if_too_many_duplicate_types_in_interface(result) + subscriptions_data = result["data"]["subscriptions"] subscriptions = subscriptions_data["page"] pageinfo = subscriptions_data["pageInfo"] - assert "errors" not in result assert len(subscriptions) == 1 assert pageinfo == { "hasPreviousPage": False,