diff --git a/tests/steps/test_tensorflow_21_urllib3.py b/tests/steps/test_tensorflow_21_urllib3.py index 607a59902..020061995 100644 --- a/tests/steps/test_tensorflow_21_urllib3.py +++ b/tests/steps/test_tensorflow_21_urllib3.py @@ -20,9 +20,9 @@ import flexmock import pytest +from thoth.adviser.exceptions import NotAcceptable from thoth.adviser.steps import TensorFlow21Urllib3Step from thoth.adviser.state import State -from thoth.common import get_justification_link as jl from thoth.python import PackageVersion from thoth.python import Source @@ -53,17 +53,11 @@ def test_tf_21(self, urllib3_version: str, tf_version: str) -> None: context = flexmock() with TensorFlow21Urllib3Step.assigned_context(context): unit = TensorFlow21Urllib3Step() - assert unit.run(state, urllib3_package_version) == ( - -0.8, - [ - { - "message": "TensorFlow in version 2.1 can cause runtime errors when " - "imported, caused by incompatibility between urllib3 and six packages", - "type": "WARNING", - "link": jl("tf_21_urllib3"), - } - ], - ) + unit.pre_run() + assert unit._message_logged is False + with pytest.raises(NotAcceptable): + assert unit.run(state, urllib3_package_version) + assert unit._message_logged is True @pytest.mark.parametrize("urllib3_version,tf_version", [("1.2", "2.2.0"), ("1.25.10", "2.1")]) def test_no_tf_21(self, urllib3_version: str, tf_version: str) -> None: diff --git a/thoth/adviser/steps/tensorflow_21_urllib3.py b/thoth/adviser/steps/tensorflow_21_urllib3.py index ecd5434a3..7938fcbc6 100644 --- a/thoth/adviser/steps/tensorflow_21_urllib3.py +++ b/thoth/adviser/steps/tensorflow_21_urllib3.py @@ -17,6 +17,7 @@ """Suggest not to use TensorFlow 2.1 with specific urllib3 versions that cause six import issues.""" +import attr from typing import Any from typing import Optional from typing import Tuple @@ -29,6 +30,7 @@ from thoth.python import PackageVersion from ..enums import RecommendationType +from ..exceptions import NotAcceptable from ..state import State from ..step import Step @@ -40,26 +42,23 @@ _LOGGER = logging.getLogger(__name__) +@attr.s(slots=True) class TensorFlow21Urllib3Step(Step): """A step that suggests not to use TensorFlow 2.1 with specific urllib3 versions that cause six import issues.""" + _message_logged = attr.ib(type=bool, default=False, init=False) + # Run this step each time, regardless of when TensorFlow and urllib3 are resolved. MULTI_PACKAGE_RESOLUTIONS = True - _SCORE_ADDITION = -0.8 - _JUSTIFICATION_ADDITION = [ - { - "type": "WARNING", - "message": ( - "TensorFlow in version 2.1 can cause runtime errors when imported, caused by " - "incompatibility between urllib3 and six packages" - ), - "link": jl("tf_21_urllib3"), - } - ] - + _MESSAGE = f"TensorFlow in version 2.1 can cause runtime errors when imported, caused by " \ + f"incompatibility between urllib3 and six packages - see {jl('tf_21_urllib3')}" _AFFECTED_URLLIB3_VERSIONS = frozenset({(1, 2), (1, 3), (1, 4), (1, 5)}) + def pre_run(self) -> None: + """Initialize this pipeline unit before each run.""" + self._message_logged = False + @classmethod def should_include(cls, builder_context: "PipelineBuilderContext") -> Optional[Dict[str, Any]]: """Register this pipeline unit for adviser if not using latest recommendations.""" @@ -87,4 +86,8 @@ def run( if "tensorflow" not in state.resolved_dependencies or state.resolved_dependencies["tensorflow"][1][:3] != "2.1": return None - return self._SCORE_ADDITION, self._JUSTIFICATION_ADDITION + if not self._message_logged: + self._message_logged = True + _LOGGER.warning(self._MESSAGE) + + raise NotAcceptable