From dc0548e2a5a141d7f139488619e82cafef1102a3 Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Fri, 9 Aug 2024 20:54:19 +0200 Subject: [PATCH] Fix TriggerDagRunOperator Tests for Database Isolation Tests (#41298) * Attempt to fix TriggerDagRunOperator for Database Isolation Tests * Finalize making tests run for triggerdagrunoperator in db isolation mode * Adjust query count assert for adjustments to serialization * Review feedback (cherry picked from commit 6b810b89c3f63dd2d2cf107c568be40ba9da0ba2) --- airflow/api/common/trigger_dag.py | 8 + .../endpoints/rpc_api_endpoint.py | 2 + airflow/exceptions.py | 22 + airflow/models/dag.py | 4 + airflow/operators/trigger_dagrun.py | 14 + airflow/serialization/serialized_objects.py | 7 +- tests/models/test_dag.py | 2 +- tests/operators/test_trigger_dagrun.py | 675 ++++++++++-------- 8 files changed, 451 insertions(+), 283 deletions(-) diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py index 86513f78333c2..f22755ec640ea 100644 --- a/airflow/api/common/trigger_dag.py +++ b/airflow/api/common/trigger_dag.py @@ -22,15 +22,19 @@ import json from typing import TYPE_CHECKING +from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import DagNotFound, DagRunAlreadyExists from airflow.models import DagBag, DagModel, DagRun from airflow.utils import timezone +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType if TYPE_CHECKING: from datetime import datetime + from sqlalchemy.orm.session import Session + def _trigger_dag( dag_id: str, @@ -103,12 +107,15 @@ def _trigger_dag( return dag_runs +@internal_api_call +@provide_session def trigger_dag( dag_id: str, run_id: str | None = None, conf: dict | str | None = None, execution_date: datetime | None = None, replace_microseconds: bool = True, + session: Session = NEW_SESSION, ) -> DagRun | None: """ Triggers execution of DAG specified by dag_id. @@ -118,6 +125,7 @@ def trigger_dag( :param conf: configuration :param execution_date: date of execution :param replace_microseconds: whether microseconds should be zeroed + :param session: Unused. 
Only added for compatibility with database isolation mode
    :return: first dag run triggered - even if more than one Dag Runs were triggered or None
    """
    dag_model = DagModel.get_current(dag_id)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index be4699fa6c7dd..ad65157ef9415 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -53,6 +53,7 @@
 @functools.lru_cache
 def initialize_method_map() -> dict[str, Callable]:
+    from airflow.api.common.trigger_dag import trigger_dag
     from airflow.cli.commands.task_command import _get_ti_db_access
     from airflow.dag_processing.manager import DagFileProcessorManager
     from airflow.dag_processing.processor import DagFileProcessor
@@ -92,6 +93,7 @@ def initialize_method_map() -> dict[str, Callable]:
         _add_log,
         _xcom_pull,
         _record_task_map_for_downstreams,
+        trigger_dag,
         DagCode.remove_deleted_code,
         DagModel.deactivate_deleted_dags,
         DagModel.get_paused_dag_ids,
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 40a62ad20854c..3831d909fc272 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -239,6 +239,28 @@ def __init__(self, dag_run: DagRun, execution_date: datetime.datetime, run_id: s
             f"A DAG Run already exists for DAG {dag_run.dag_id} at {execution_date} with run id {run_id}"
         )
         self.dag_run = dag_run
+        self.execution_date = execution_date
+        self.run_id = run_id
+
+    def serialize(self):
+        cls = self.__class__
+        # Note: the DagRun object is detached here and would fail serialization, so we create a fresh one
+        from airflow.models import DagRun
+
+        dag_run = DagRun(
+            state=self.dag_run.state,
+            dag_id=self.dag_run.dag_id,
+            run_id=self.dag_run.run_id,
+            external_trigger=self.dag_run.external_trigger,
+            run_type=self.dag_run.run_type,
+            execution_date=self.dag_run.execution_date,
+        )
+        dag_run.id = self.dag_run.id
+        return (
+            f"{cls.__module__}.{cls.__name__}",
+            (),
+            {"dag_run": dag_run, "execution_date": self.execution_date, "run_id": self.run_id},
+        )
 
 
 class DagFileExists(AirflowBadRequest):
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index c9380494a034d..1c9d351c1d292 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -115,6 +115,7 @@
     TaskInstanceKey,
     clear_task_instances,
 )
+from airflow.models.tasklog import LogTemplate
 from airflow.secrets.local_filesystem import LocalFilesystemBackend
 from airflow.security import permissions
 from airflow.settings import json
@@ -338,6 +339,9 @@ def _create_orm_dagrun(
         creating_job_id=creating_job_id,
         data_interval=data_interval,
     )
+    # Load defaults into the following two fields so the result can be serialized even when detached
+    run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
+    run.consumed_dataset_events = []
     session.add(run)
     session.flush()
     run.dag = dag
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 35d387738a0d3..2521297dcf936 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -27,6 +27,7 @@
 from sqlalchemy.orm.exc import NoResultFound
 
 from airflow.api.common.trigger_dag import trigger_dag
+from airflow.api_internal.internal_api_call import InternalApiConfig
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
@@ -83,6 +84,8 @@ class TriggerDagRunOperator(BaseOperator):
     """
     Triggers a DAG run for a specified DAG ID.
 
+    Note that if database isolation mode is enabled, not all features are supported.
+
     :param trigger_dag_id: The ``dag_id`` of the DAG to trigger (templated).
     :param trigger_run_id: The run ID to use for the triggered DAG run (templated).
         If not provided, a run ID will be automatically generated.
@@ -174,6 +177,14 @@ def __init__(
         self.logical_date = logical_date
 
     def execute(self, context: Context):
+        if InternalApiConfig.get_use_internal_api():
+            if self.reset_dag_run:
+                raise AirflowException("Parameter reset_dag_run=True is not supported in Database Isolation Mode.")
+            if self.wait_for_completion:
+                raise AirflowException(
+                    "Parameter wait_for_completion=True is not supported in Database Isolation Mode."
+                )
+
         if isinstance(self.logical_date, datetime.datetime):
             parsed_logical_date = self.logical_date
         elif isinstance(self.logical_date, str):
@@ -210,6 +221,7 @@ def execute(self, context: Context):
             if dag_model is None:
                 raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")
 
+            # Note: execution fails here in database isolation mode. Needs structural changes for AIP-72
             dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
             dag = dag_bag.get_dag(self.trigger_dag_id)
             dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
@@ -250,6 +262,7 @@ def execute(self, context: Context):
                 )
                 time.sleep(self.poke_interval)
 
+                # Note: execution fails here in database isolation mode. Needs structural changes for AIP-72
                 dag_run.refresh_from_db()
                 state = dag_run.state
                 if state in self.failed_states:
@@ -263,6 +276,7 @@ def execute_complete(self, context: Context, session: Session, event: tuple[str,
         # This logical_date is parsed from the return trigger event
         provided_logical_date = event[1]["execution_dates"][0]
         try:
+            # Note: execution fails here in database isolation mode. Needs structural changes for AIP-72
             dag_run = session.execute(
                 select(DagRun).where(
                     DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 94631c993c122..d110271c3da08 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1447,7 +1447,12 @@ def get_custom_dep() -> list[DagDependency]:
 
     @classmethod
     def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
-        if var is not None and op.has_dag() and attrname.endswith("_date"):
+        if (
+            var is not None
+            and op.has_dag()
+            and op.dag.__class__ is not AttributeRemoved
+            and attrname.endswith("_date")
+        ):
             # If this date is the same as the matching field in the dag, then
             # don't store it again at the task level.
             dag_date = getattr(op.dag, attrname, None)
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 376d5c5beb170..3d39a7290d909 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -3293,7 +3293,7 @@ def test_count_number_queries(self, tasks_count):
         dag = DAG("test_dagrun_query_count", start_date=DEFAULT_DATE)
         for i in range(tasks_count):
             EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag)
-        with assert_queries_count(2):
+        with assert_queries_count(3):
             dag.create_dagrun(
                 run_id="test_dagrun_query_count",
                 state=State.RUNNING,
diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py
index 341b34fe46fc6..349bba463800f 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -17,7 +17,6 @@
 # under the License.
from __future__ import annotations -import pathlib import tempfile from datetime import datetime from unittest import mock @@ -26,13 +25,14 @@ import pytest from airflow.exceptions import AirflowException, DagRunAlreadyExists, RemovedInAirflow3Warning, TaskDeferred -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.log import Log from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.settings import TracebackSessionForTests from airflow.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.session import create_session @@ -67,15 +67,18 @@ def setup_method(self): self._tmpfile = f.name f.write(DAG_SCRIPT) f.flush() + self.f_name = f.name with create_session() as session: session.add(DagModel(dag_id=TRIGGERED_DAG_ID, fileloc=self._tmpfile)) session.commit() - self.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) - dagbag = DagBag(f.name, read_dags_from_db=False, include_examples=False) - dagbag.bag_dag(self.dag, root_dag=self.dag) - dagbag.sync_to_db() + def re_sync_triggered_dag_to_db(self, dag, dag_maker): + TracebackSessionForTests.set_allow_db_access(dag_maker.session, True) + dagbag = DagBag(self.f_name, read_dags_from_db=False, include_examples=False) + dagbag.bag_dag(dag, root_dag=dag) + dagbag.sync_to_db(session=dag_maker.session) + TracebackSessionForTests.set_allow_db_access(dag_maker.session, False) def teardown_method(self): """Cleanup state after testing in DB.""" @@ -86,7 +89,7 @@ def teardown_method(self): synchronize_session=False ) - pathlib.Path(self._tmpfile).unlink() + # pathlib.Path(self._tmpfile).unlink() def assert_extra_link(self, triggered_dag_run, triggering_task, session): """ @@ -115,24 +118,32 @@ def assert_extra_link(self, triggered_dag_run, triggering_task, session): } assert expected_args in args - def test_trigger_dagrun(self): + def test_trigger_dagrun(self, dag_maker): """Test TriggerDagRunOperator.""" - task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, dag=self.dag) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - with create_session() as session: - dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() - assert dagrun.external_trigger - assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, dagrun.logical_date) - self.assert_extra_link(dagrun, task, session) + dagrun = dag_maker.session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.external_trigger + assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, dagrun.logical_date) + self.assert_extra_link(dagrun, task, dag_maker.session) - def test_trigger_dagrun_custom_run_id(self): - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id="custom_run_id", - dag=self.dag, - ) + def test_trigger_dagrun_custom_run_id(self, dag_maker): + with dag_maker( + TEST_DAG_ID, 
default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="custom_run_id", + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) with create_session() as session: @@ -140,15 +151,19 @@ def test_trigger_dagrun_custom_run_id(self): assert len(dagruns) == 1 assert dagruns[0].run_id == "custom_run_id" - def test_trigger_dagrun_with_logical_date(self): + def test_trigger_dagrun_with_logical_date(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date.""" custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=custom_logical_date, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=custom_logical_date, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: @@ -158,78 +173,91 @@ def test_trigger_dagrun_with_logical_date(self): assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, custom_logical_date) self.assert_extra_link(dagrun, task, session) - def test_trigger_dagrun_twice(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_twice(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date.""" utc_now = timezone.utcnow() - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=utc_now, - dag=self.dag, - poke_interval=1, - reset_dag_run=True, - wait_for_completion=True, - ) run_id = f"manual__{utc_now.isoformat()}" - with create_session() as session: - dag_run = DagRun( - dag_id=TRIGGERED_DAG_ID, - execution_date=utc_now, - state=State.SUCCESS, - run_type="manual", - run_id=run_id, + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id=run_id, + logical_date=utc_now, + poke_interval=1, + reset_dag_run=True, + wait_for_completion=True, ) - session.add(dag_run) - session.commit() - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() + dag_run = DagRun( + dag_id=TRIGGERED_DAG_ID, + execution_date=utc_now, + state=State.SUCCESS, + run_type="manual", + run_id=run_id, + ) + dag_maker.session.add(dag_run) + dag_maker.session.commit() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() - assert len(dagruns) == 1 - triggered_dag_run = dagruns[0] - assert triggered_dag_run.external_trigger - assert triggered_dag_run.logical_date == utc_now - self.assert_extra_link(triggered_dag_run, task, session) + dagruns = dag_maker.session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() + 
assert len(dagruns) == 1 + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.logical_date == utc_now + self.assert_extra_link(triggered_dag_run, task, dag_maker.session) - def test_trigger_dagrun_with_scheduled_dag_run(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date and scheduled dag_run.""" utc_now = timezone.utcnow() - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=utc_now, - dag=self.dag, - poke_interval=1, - reset_dag_run=True, - wait_for_completion=True, - ) - run_id = f"scheduled__{utc_now.isoformat()}" - with create_session() as session: - dag_run = DagRun( - dag_id=TRIGGERED_DAG_ID, - execution_date=utc_now, - state=State.SUCCESS, - run_type="scheduled", - run_id=run_id, + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=utc_now, + poke_interval=1, + reset_dag_run=True, + wait_for_completion=True, ) - session.add(dag_run) - session.commit() - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() + run_id = f"scheduled__{utc_now.isoformat()}" + dag_run = DagRun( + dag_id=TRIGGERED_DAG_ID, + execution_date=utc_now, + state=State.SUCCESS, + run_type="scheduled", + run_id=run_id, + ) + dag_maker.session.add(dag_run) + dag_maker.session.commit() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() - assert len(dagruns) == 1 - triggered_dag_run = dagruns[0] - assert triggered_dag_run.external_trigger - assert triggered_dag_run.logical_date == utc_now - self.assert_extra_link(triggered_dag_run, task, session) + dagruns = dag_maker.session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() + assert len(dagruns) == 1 + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.logical_date == utc_now + self.assert_extra_link(triggered_dag_run, task, dag_maker.session) - def test_trigger_dagrun_with_templated_logical_date(self): + def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): """Test TriggerDagRunOperator with templated logical_date.""" - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_str_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date="{{ logical_date }}", - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_str_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date="{{ logical_date }}", + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: @@ -240,14 +268,18 @@ def test_trigger_dagrun_with_templated_logical_date(self): assert triggered_dag_run.logical_date == DEFAULT_DATE self.assert_extra_link(triggered_dag_run, task, session) - def 
test_trigger_dagrun_operator_conf(self): + def test_trigger_dagrun_operator_conf(self, dag_maker): """Test passing conf to the triggered DagRun.""" - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_str_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - conf={"foo": "bar"}, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_str_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + conf={"foo": "bar"}, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: @@ -255,25 +287,33 @@ def test_trigger_dagrun_operator_conf(self): assert len(dagruns) == 1 assert dagruns[0].conf == {"foo": "bar"} - def test_trigger_dagrun_operator_templated_invalid_conf(self): + def test_trigger_dagrun_operator_templated_invalid_conf(self, dag_maker): """Test passing a conf that is not JSON Serializable raise error.""" - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_invalid_conf", - trigger_dag_id=TRIGGERED_DAG_ID, - conf={"foo": "{{ dag.dag_id }}", "datetime": timezone.utcnow()}, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_invalid_conf", + trigger_dag_id=TRIGGERED_DAG_ID, + conf={"foo": "{{ dag.dag_id }}", "datetime": timezone.utcnow()}, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() with pytest.raises(AirflowException, match="^conf parameter should be JSON Serializable$"): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - def test_trigger_dagrun_operator_templated_conf(self): + def test_trigger_dagrun_operator_templated_conf(self, dag_maker): """Test passing a templated conf to the triggered DagRun.""" - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_str_logical_date", - trigger_dag_id=TRIGGERED_DAG_ID, - conf={"foo": "{{ dag.dag_id }}"}, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_str_logical_date", + trigger_dag_id=TRIGGERED_DAG_ID, + conf={"foo": "{{ dag.dag_id }}"}, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: @@ -281,17 +321,21 @@ def test_trigger_dagrun_operator_templated_conf(self): assert len(dagruns) == 1 assert dagruns[0].conf == {"foo": TEST_DAG_ID} - def test_trigger_dagrun_with_reset_dag_run_false(self): + def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): """Test TriggerDagRunOperator without reset_dag_run.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id=None, - logical_date=None, - reset_dag_run=False, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id=None, + logical_date=None, + reset_dag_run=False, + ) + 
self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -307,39 +351,50 @@ def test_trigger_dagrun_with_reset_dag_run_false(self): ("dummy_run_id", DEFAULT_DATE), ], ) - def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trigger_logical_date): + def test_trigger_dagrun_with_reset_dag_run_false_fail( + self, trigger_run_id, trigger_logical_date, dag_maker + ): """Test TriggerDagRunOperator without reset_dag_run but triggered dag fails.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id=trigger_run_id, - logical_date=trigger_logical_date, - reset_dag_run=False, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id=trigger_run_id, + logical_date=trigger_logical_date, + reset_dag_run=False, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) with pytest.raises(DagRunAlreadyExists): task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) - def test_trigger_dagrun_with_skip_when_already_exists(self): + def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): """Test TriggerDagRunOperator with skip_when_already_exists.""" execution_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id="dummy_run_id", - execution_date=None, - reset_dag_run=False, - skip_when_already_exists=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="dummy_run_id", + execution_date=None, + reset_dag_run=False, + skip_when_already_exists=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dr: DagRun = dag_maker.create_dagrun() task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) - assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS + assert dr.get_task_instance("test_task").state == TaskInstanceState.SUCCESS task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True) - assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED + assert dr.get_task_instance("test_task").state == TaskInstanceState.SKIPPED + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode @pytest.mark.parametrize( "trigger_run_id, trigger_logical_date, expected_dagruns_count", [ @@ -350,18 +405,22 @@ def test_trigger_dagrun_with_skip_when_already_exists(self): ], ) def test_trigger_dagrun_with_reset_dag_run_true( - self, trigger_run_id, trigger_logical_date, expected_dagruns_count + self, trigger_run_id, trigger_logical_date, expected_dagruns_count, dag_maker ): """Test TriggerDagRunOperator with reset_dag_run.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id=trigger_run_id, - logical_date=trigger_logical_date, - reset_dag_run=True, - 
dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id=trigger_run_id, + logical_date=trigger_logical_date, + reset_dag_run=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -370,106 +429,132 @@ def test_trigger_dagrun_with_reset_dag_run_true( assert len(dag_runs) == expected_dagruns_count assert dag_runs[0].external_trigger - def test_trigger_dagrun_with_wait_for_completion_true(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): """Test TriggerDagRunOperator with wait_for_completion.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - allowed_states=[State.QUEUED], - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + allowed_states=[State.QUEUED], + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - def test_trigger_dagrun_with_wait_for_completion_true_fail(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): """Test TriggerDagRunOperator with wait_for_completion but triggered dag fails.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - failed_states=[State.QUEUED], - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + failed_states=[State.QUEUED], + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() with pytest.raises(AirflowException): task.run(start_date=logical_date, end_date=logical_date) - def test_trigger_dagrun_triggering_itself(self): + def test_trigger_dagrun_triggering_itself(self, dag_maker): """Test TriggerDagRunOperator that triggers itself""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=self.dag.dag_id, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TEST_DAG_ID, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, 
end_date=logical_date) - with create_session() as session: - dagruns = ( - session.query(DagRun) - .filter(DagRun.dag_id == self.dag.dag_id) - .order_by(DagRun.execution_date) - .all() - ) - assert len(dagruns) == 2 - triggered_dag_run = dagruns[1] - assert triggered_dag_run.state == State.QUEUED - self.assert_extra_link(triggered_dag_run, task, session) + dagruns = ( + dag_maker.session.query(DagRun) + .filter(DagRun.dag_id == TEST_DAG_ID) + .order_by(DagRun.execution_date) + .all() + ) + assert len(dagruns) == 2 + triggered_dag_run = dagruns[1] + assert triggered_dag_run.state == State.QUEUED - def test_trigger_dagrun_triggering_itself_with_logical_date(self): + def test_trigger_dagrun_triggering_itself_with_logical_date(self, dag_maker): """Test TriggerDagRunOperator that triggers itself with logical date, fails with DagRunAlreadyExists""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=self.dag.dag_id, - logical_date=logical_date, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TEST_DAG_ID, + logical_date=logical_date, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() with pytest.raises(DagRunAlreadyExists): task.run(start_date=logical_date, end_date=logical_date) - def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_maker): """Test TriggerDagRunOperator with wait_for_completion.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - allowed_states=[State.QUEUED], - deferrable=False, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + allowed_states=[State.QUEUED], + deferrable=False, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker): """Test TriggerDagRunOperator with wait_for_completion.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - allowed_states=[State.QUEUED], - deferrable=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + allowed_states=[State.QUEUED], + 
deferrable=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -485,19 +570,24 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self): task.execute_complete(context={}, event=trigger.serialize()) - def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, dag_maker): """Test TriggerDagRunOperator wait_for_completion dag run in non defined state.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - allowed_states=[State.SUCCESS], - deferrable=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + allowed_states=[State.SUCCESS], + deferrable=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -517,20 +607,25 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self): event=trigger.serialize(), ) - def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, dag_maker): """Test TriggerDagRunOperator wait_for_completion dag run in failed state.""" logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - allowed_states=[State.SUCCESS], - failed_states=[State.QUEUED], - deferrable=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + allowed_states=[State.SUCCESS], + failed_states=[State.QUEUED], + deferrable=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -548,19 +643,23 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self) with pytest.raises(AirflowException, match="failed with failed state"): task.execute_complete(context={}, event=trigger.serialize()) - def test_trigger_dagrun_with_execution_date(self): + def test_trigger_dagrun_with_execution_date(self, dag_maker): """Test TriggerDagRunOperator with custom execution_date (deprecated parameter)""" custom_execution_date = timezone.datetime(2021, 1, 2, 3, 4, 5) - with pytest.warns( - RemovedInAirflow3Warning, - match="Parameter 'execution_date' is deprecated. 
Use 'logical_date' instead.", - ): - task = TriggerDagRunOperator( - task_id="test_trigger_dagrun_with_execution_date", - trigger_dag_id=TRIGGERED_DAG_ID, - execution_date=custom_execution_date, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + with pytest.warns( + RemovedInAirflow3Warning, + match="Parameter 'execution_date' is deprecated. Use 'logical_date' instead.", + ): + task = TriggerDagRunOperator( + task_id="test_trigger_dagrun_with_execution_date", + trigger_dag_id=TRIGGERED_DAG_ID, + execution_date=custom_execution_date, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: @@ -570,6 +669,7 @@ def test_trigger_dagrun_with_execution_date(self): assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, custom_execution_date) self.assert_extra_link(dagrun, task, session) + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode @pytest.mark.parametrize( argnames=["trigger_logical_date"], argvalues=[ @@ -577,18 +677,22 @@ def test_trigger_dagrun_with_execution_date(self): pytest.param(None, id="logical_date=None"), ], ) - def test_dagstatetrigger_execution_dates(self, trigger_logical_date): + def test_dagstatetrigger_execution_dates(self, trigger_logical_date, dag_maker): """Ensure that the DagStateTrigger is called with the triggered DAG's logical date.""" - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=trigger_logical_date, - wait_for_completion=True, - poke_interval=5, - allowed_states=[DagRunState.QUEUED], - deferrable=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=trigger_logical_date, + wait_for_completion=True, + poke_interval=5, + allowed_states=[DagRunState.QUEUED], + deferrable=True, + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) with mock.patch.object(TriggerDagRunOperator, "defer", mock_task_defer), pytest.raises(TaskDeferred): @@ -602,19 +706,24 @@ def test_dagstatetrigger_execution_dates(self, trigger_logical_date): pendulum.instance(dagruns[0].logical_date) ] - def test_dagstatetrigger_execution_dates_with_clear_and_reset(self): + @pytest.mark.skip_if_database_isolation_mode # Known to be broken in db isolation mode + def test_dagstatetrigger_execution_dates_with_clear_and_reset(self, dag_maker): """Check DagStateTrigger is called with the triggered DAG's logical date on subsequent defers.""" - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - trigger_run_id="custom_run_id", - wait_for_completion=True, - poke_interval=5, - allowed_states=[DagRunState.QUEUED], - deferrable=True, - reset_dag_run=True, - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="custom_run_id", + wait_for_completion=True, + poke_interval=5, + allowed_states=[DagRunState.QUEUED], + deferrable=True, + reset_dag_run=True, + ) + 
self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) with mock.patch.object(TriggerDagRunOperator, "defer", mock_task_defer), pytest.raises(TaskDeferred): @@ -647,16 +756,20 @@ def test_dagstatetrigger_execution_dates_with_clear_and_reset(self): pendulum.instance(triggered_logical_date) ] - def test_trigger_dagrun_with_no_failed_state(self): + def test_trigger_dagrun_with_no_failed_state(self, dag_maker): logical_date = DEFAULT_DATE - task = TriggerDagRunOperator( - task_id="test_task", - trigger_dag_id=TRIGGERED_DAG_ID, - logical_date=logical_date, - wait_for_completion=True, - poke_interval=10, - failed_states=[], - dag=self.dag, - ) + with dag_maker( + TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True + ) as dag: + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + logical_date=logical_date, + wait_for_completion=True, + poke_interval=10, + failed_states=[], + ) + self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.create_dagrun() assert task.failed_states == []
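

Reviewer note: the serialize() method added to DagRunAlreadyExists above returns a
(classpath, args, kwargs) triple so the exception can cross the internal API (RPC)
boundary in database isolation mode. Below is a minimal sketch of the consuming side;
the deserialize_exception helper is hypothetical (it is not part of this patch) and
only illustrates that the triple carries enough state to rebuild the exception after
the round trip.

    import importlib

    def deserialize_exception(classpath: str, args: tuple, kwargs: dict) -> Exception:
        # Hypothetical helper, not part of this patch.
        # Split e.g. "airflow.exceptions.DagRunAlreadyExists" into module and class name.
        module_name, _, class_name = classpath.rpartition(".")
        exc_class = getattr(importlib.import_module(module_name), class_name)
        # Rebuild the exception; for DagRunAlreadyExists the kwargs are
        # {"dag_run": ..., "execution_date": ..., "run_id": ...}, matching its __init__.
        return exc_class(*args, **kwargs)

This also explains why serialize() constructs a fresh DagRun from plain column values:
the original ORM object is detached from its session, so only a rebuilt, self-contained
copy can safely travel through the serialization layer.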