Skip to content

Commit

Permalink
Fix TriggerDagRunOperator Tests for Database Isolation Tests (#41298)
Browse files Browse the repository at this point in the history
* Attempt to fix TriggerDagRunOperator for Database Isolation Tests

* Finalize making tests run for triggerdagrunoperator in db isolation mode

* Adjust query count assert for adjustments to serialization

* Review feedback

(cherry picked from commit 6b810b8)
  • Loading branch information
jscheffl authored and potiuk committed Aug 11, 2024
1 parent a7d48cb commit dc0548e
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 283 deletions.
8 changes: 8 additions & 0 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
import json
from typing import TYPE_CHECKING

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import DagNotFound, DagRunAlreadyExists
from airflow.models import DagBag, DagModel, DagRun
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from datetime import datetime

from sqlalchemy.orm.session import Session


def _trigger_dag(
dag_id: str,
Expand Down Expand Up @@ -103,12 +107,15 @@ def _trigger_dag(
return dag_runs


@internal_api_call
@provide_session
def trigger_dag(
dag_id: str,
run_id: str | None = None,
conf: dict | str | None = None,
execution_date: datetime | None = None,
replace_microseconds: bool = True,
session: Session = NEW_SESSION,
) -> DagRun | None:
"""
Triggers execution of DAG specified by dag_id.
Expand All @@ -118,6 +125,7 @@ def trigger_dag(
:param conf: configuration
:param execution_date: date of execution
:param replace_microseconds: whether microseconds should be zeroed
:param session: Unused. Only added in compatibility with database isolation mode
:return: first dag run triggered - even if more than one Dag Runs were triggered or None
"""
dag_model = DagModel.get_current(dag_id)
Expand Down
2 changes: 2 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

@functools.lru_cache
def initialize_method_map() -> dict[str, Callable]:
from airflow.api.common.trigger_dag import trigger_dag
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
Expand Down Expand Up @@ -92,6 +93,7 @@ def initialize_method_map() -> dict[str, Callable]:
_add_log,
_xcom_pull,
_record_task_map_for_downstreams,
trigger_dag,
DagCode.remove_deleted_code,
DagModel.deactivate_deleted_dags,
DagModel.get_paused_dag_ids,
Expand Down
22 changes: 22 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,28 @@ def __init__(self, dag_run: DagRun, execution_date: datetime.datetime, run_id: s
f"A DAG Run already exists for DAG {dag_run.dag_id} at {execution_date} with run id {run_id}"
)
self.dag_run = dag_run
self.execution_date = execution_date
self.run_id = run_id

def serialize(self):
cls = self.__class__
# Note the DagRun object will be detached here and fails serialization, we need to create a new one
from airflow.models import DagRun

dag_run = DagRun(
state=self.dag_run.state,
dag_id=self.dag_run.dag_id,
run_id=self.dag_run.run_id,
external_trigger=self.dag_run.external_trigger,
run_type=self.dag_run.run_type,
execution_date=self.dag_run.execution_date,
)
dag_run.id = self.dag_run.id
return (
f"{cls.__module__}.{cls.__name__}",
(),
{"dag_run": dag_run, "execution_date": self.execution_date, "run_id": self.run_id},
)


class DagFileExists(AirflowBadRequest):
Expand Down
4 changes: 4 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
TaskInstanceKey,
clear_task_instances,
)
from airflow.models.tasklog import LogTemplate
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.security import permissions
from airflow.settings import json
Expand Down Expand Up @@ -338,6 +339,9 @@ def _create_orm_dagrun(
creating_job_id=creating_job_id,
data_interval=data_interval,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
run.consumed_dataset_events = []
session.add(run)
session.flush()
run.dag = dag
Expand Down
14 changes: 14 additions & 0 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_internal.internal_api_call import InternalApiConfig
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -83,6 +84,8 @@ class TriggerDagRunOperator(BaseOperator):
"""
Triggers a DAG run for a specified DAG ID.
Note that if database isolation mode is enabled, not all features are supported.
:param trigger_dag_id: The ``dag_id`` of the DAG to trigger (templated).
:param trigger_run_id: The run ID to use for the triggered DAG run (templated).
If not provided, a run ID will be automatically generated.
Expand Down Expand Up @@ -174,6 +177,14 @@ def __init__(
self.logical_date = logical_date

def execute(self, context: Context):
if InternalApiConfig.get_use_internal_api():
if self.reset_dag_run:
raise AirflowException("Parameter reset_dag_run=True is broken with Database Isolation Mode.")
if self.wait_for_completion:
raise AirflowException(
"Parameter wait_for_completion=True is broken with Database Isolation Mode."
)

if isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date
elif isinstance(self.logical_date, str):
Expand Down Expand Up @@ -210,6 +221,7 @@ def execute(self, context: Context):
if dag_model is None:
raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")

# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
Expand Down Expand Up @@ -250,6 +262,7 @@ def execute(self, context: Context):
)
time.sleep(self.poke_interval)

# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run.refresh_from_db()
state = dag_run.state
if state in self.failed_states:
Expand All @@ -263,6 +276,7 @@ def execute_complete(self, context: Context, session: Session, event: tuple[str,
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
Expand Down
7 changes: 6 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,12 @@ def get_custom_dep() -> list[DagDependency]:

@classmethod
def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
if var is not None and op.has_dag() and attrname.endswith("_date"):
if (
var is not None
and op.has_dag()
and op.dag.__class__ is not AttributeRemoved
and attrname.endswith("_date")
):
# If this date is the same as the matching field in the dag, then
# don't store it again at the task level.
dag_date = getattr(op.dag, attrname, None)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3293,7 +3293,7 @@ def test_count_number_queries(self, tasks_count):
dag = DAG("test_dagrun_query_count", start_date=DEFAULT_DATE)
for i in range(tasks_count):
EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag)
with assert_queries_count(2):
with assert_queries_count(3):
dag.create_dagrun(
run_id="test_dagrun_query_count",
state=State.RUNNING,
Expand Down
Loading

0 comments on commit dc0548e

Please sign in to comment.