diff --git a/.circleci/docker-compose.cypress.yml b/.circleci/docker-compose.cypress.yml
index 20cc9980e4..6781c3d5ed 100644
--- a/.circleci/docker-compose.cypress.yml
+++ b/.circleci/docker-compose.cypress.yml
@@ -31,7 +31,6 @@ services:
       REDASH_LOG_LEVEL: "INFO"
       REDASH_REDIS_URL: "redis://redis:6379/0"
       REDASH_DATABASE_URL: "postgresql://postgres@postgres/postgres"
-      QUEUES: "default periodic schemas"
   celery_worker:
     build: ../
     command: celery_worker
diff --git a/redash/app.py b/redash/app.py
index 20bcf20927..1c592b3e70 100644
--- a/redash/app.py
+++ b/redash/app.py
@@ -23,7 +23,16 @@ def __init__(self, *args, **kwargs):
 
 
 def create_app():
-    from . import authentication, extensions, handlers, limiter, mail, migrate, security
+    from . import (
+        authentication,
+        extensions,
+        handlers,
+        limiter,
+        mail,
+        migrate,
+        security,
+        tasks,
+    )
     from .handlers.webpack import configure_webpack
     from .metrics import request as request_metrics
     from .models import db, users
@@ -47,5 +56,6 @@ def create_app():
     configure_webpack(app)
     extensions.init_app(app)
     users.init_app(app)
+    tasks.init_app(app)
 
     return app
diff --git a/redash/cli/rq.py b/redash/cli/rq.py
index dab96f55c9..e6a2994c6c 100644
--- a/redash/cli/rq.py
+++ b/redash/cli/rq.py
@@ -5,11 +5,12 @@
 from click import argument
 from flask.cli import AppGroup
-from rq import Connection, Worker
+from rq import Connection
 from sqlalchemy.orm import configure_mappers
 
 from redash import rq_redis_connection
-from redash.schedule import (
+from redash.tasks import Worker
+from redash.tasks.schedule import (
     rq_scheduler,
     schedule_periodic_jobs,
     periodic_job_definitions,
@@ -34,10 +35,10 @@ def worker(queues):
     configure_mappers()
 
     if not queues:
-        queues = ["periodic", "emails", "default", "schemas"]
+        queues = ["scheduled_queries", "queries", "periodic", "emails", "default", "schemas"]
 
     with Connection(rq_redis_connection):
-        w = Worker(queues, log_job_description=False)
+        w = Worker(queues, log_job_description=False, job_monitoring_interval=5)
 
         w.work()
diff --git a/redash/handlers/query_results.py b/redash/handlers/query_results.py
index 2056e7a5ff..404c6e0ff6 100644
--- a/redash/handlers/query_results.py
+++ b/redash/handlers/query_results.py
@@ -15,8 +15,8 @@
     require_permission,
     require_any_of_permission,
     view_only,
 )
-from redash.tasks import QueryTask
+from redash.tasks import Job
 from redash.tasks.queries import enqueue_query
 from redash.utils import (
     collect_parameters_from_request,
@@ -35,6 +36,7 @@
     serialize_query_result,
     serialize_query_result_to_dsv,
     serialize_query_result_to_xlsx,
+    serialize_job,
 )
 
 
@@ -119,7 +121,7 @@ def run_query(query, parameters, data_source, query_id, max_age=0):
             "Query ID": query_id,
         },
     )
-    return {"job": job.to_dict()}
+    return serialize_job(job)
 
 
 def get_download_filename(query_result, query, filetype):
@@ -441,12 +443,12 @@ def get(self, job_id, query_id=None):
         """
         Retrieve info about a running query job.
         """
-        job = QueryTask(job_id=job_id)
-        return {"job": job.to_dict()}
+        job = Job.fetch(job_id)
+        return serialize_job(job)
 
     def delete(self, job_id):
         """
         Cancel a query job in progress.
""" - job = QueryTask(job_id=job_id) + job = Job.fetch(job_id) job.cancel() diff --git a/redash/monitor.py b/redash/monitor.py index 789903d724..3b8791683f 100644 --- a/redash/monitor.py +++ b/redash/monitor.py @@ -45,7 +45,7 @@ def get_queues_status(): }, **{ queue.name: {"size": len(queue)} - for queue in Queue.all(connection=rq_redis_connection) + for queue in Queue.all() }, } @@ -166,7 +166,7 @@ def rq_queues(): "started": fetch_jobs(q, StartedJobRegistry(queue=q).get_job_ids()), "queued": len(q.job_ids), } - for q in Queue.all(connection=rq_redis_connection) + for q in Queue.all() } @@ -189,7 +189,7 @@ def rq_workers(): "failed_jobs": w.failed_job_count, "total_working_time": w.total_working_time, } - for w in Worker.all(connection=rq_redis_connection) + for w in Worker.all() ] diff --git a/redash/serializers/__init__.py b/redash/serializers/__init__.py index b0c5fc8202..c358f5c84a 100644 --- a/redash/serializers/__init__.py +++ b/redash/serializers/__init__.py @@ -6,6 +6,8 @@ from funcy import project from flask_login import current_user +from rq.job import JobStatus +from rq.timeouts import JobTimeoutException from redash import models from redash.permissions import has_access, view_only @@ -263,3 +265,39 @@ def serialize_dashboard(obj, with_widgets=False, user=None, with_favorite_state= d["is_favorite"] = models.Favorite.is_favorite(current_user.id, obj) return d + + +def serialize_job(job): + # TODO: this is mapping to the old Job class statuses. Need to update the client side and remove this + STATUSES = { + JobStatus.QUEUED: 1, + JobStatus.STARTED: 2, + JobStatus.FINISHED: 3, + JobStatus.FAILED: 4, + } + + job_status = job.get_status() + if job.is_started: + updated_at = job.started_at or 0 + else: + updated_at = 0 + + status = STATUSES[job_status] + + if isinstance(job.result, Exception): + error = str(job.result) + status = 4 + elif job.is_cancelled: + error = "Query execution cancelled." 
+ else: + error = "" + + return { + "job": { + "id": job.id, + "updated_at": updated_at, + "status": status, + "error": error, + "query_result_id": job.result if job.is_finished and not error else None, + } + } diff --git a/redash/tasks/__init__.py b/redash/tasks/__init__.py index e485c3ebf1..98df1d12ff 100644 --- a/redash/tasks/__init__.py +++ b/redash/tasks/__init__.py @@ -6,7 +6,6 @@ purge_failed_jobs, ) from .queries import ( - QueryTask, enqueue_query, execute_query, refresh_queries, @@ -16,3 +15,14 @@ ) from .alerts import check_alerts_for_query from .failure_report import send_aggregated_errors +from .worker import Worker, Queue, Job +from .schedule import rq_scheduler, schedule_periodic_jobs, periodic_job_definitions + +from redash import rq_redis_connection +from rq.connections import push_connection, pop_connection + + +def init_app(app): + app.before_request(lambda: push_connection(rq_redis_connection)) + app.teardown_request(lambda _: pop_connection()) + diff --git a/redash/tasks/queries/__init__.py b/redash/tasks/queries/__init__.py index 2eb3fe8529..7c5d34fb0f 100644 --- a/redash/tasks/queries/__init__.py +++ b/redash/tasks/queries/__init__.py @@ -4,4 +4,4 @@ cleanup_query_results, empty_schedules, ) -from .execution import QueryTask, execute_query, enqueue_query +from .execution import execute_query, enqueue_query diff --git a/redash/tasks/queries/execution.py b/redash/tasks/queries/execution.py index dd7f2c6ee6..a725bf5ae8 100644 --- a/redash/tasks/queries/execution.py +++ b/redash/tasks/queries/execution.py @@ -2,19 +2,21 @@ import signal import time import redis -from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded -from celery.result import AsyncResult -from celery.utils.log import get_task_logger from six import text_type +from rq import get_current_job +from rq.job import JobStatus +from rq.timeouts import JobTimeoutException + from redash import models, redis_connection, settings from redash.query_runner import InterruptException +from redash.tasks.worker import Queue, Job from redash.tasks.alerts import check_alerts_for_query from redash.tasks.failure_report import track_failure from redash.utils import gen_query_hash, json_dumps, utcnow -from redash.worker import celery +from redash.worker import celery, get_job_logger -logger = get_task_logger(__name__) +logger = get_job_logger(__name__) TIMEOUT_MESSAGE = "Query exceeded Redash query execution time limit." @@ -26,69 +28,6 @@ def _unlock(query_hash, data_source_id): redis_connection.delete(_job_lock_id(query_hash, data_source_id)) -class QueryTask(object): - # TODO: this is mapping to the old Job class statuses. Need to update the client side and remove this - STATUSES = {"PENDING": 1, "STARTED": 2, "SUCCESS": 3, "FAILURE": 4, "REVOKED": 4} - - def __init__(self, job_id=None, async_result=None): - if async_result: - self._async_result = async_result - else: - self._async_result = AsyncResult(job_id, app=celery) - - @property - def id(self): - return self._async_result.id - - def to_dict(self): - task_info = self._async_result._get_task_meta() - result, task_status = task_info["result"], task_info["status"] - if task_status == "STARTED": - updated_at = result.get("start_time", 0) - else: - updated_at = 0 - - status = self.STATUSES[task_status] - - if isinstance(result, (TimeLimitExceeded, SoftTimeLimitExceeded)): - error = TIMEOUT_MESSAGE - status = 4 - elif isinstance(result, Exception): - error = str(result) - status = 4 - elif task_status == "REVOKED": - error = "Query execution cancelled." 
- else: - error = "" - - if task_status == "SUCCESS" and not error: - query_result_id = result - else: - query_result_id = None - - return { - "id": self._async_result.id, - "updated_at": updated_at, - "status": status, - "error": error, - "query_result_id": query_result_id, - } - - @property - def is_cancelled(self): - return self._async_result.status == "REVOKED" - - @property - def celery_status(self): - return self._async_result.status - - def ready(self): - return self._async_result.ready() - - def cancel(self): - return self._async_result.revoke(terminate=True, signal="SIGINT") - - def enqueue_query( query, data_source, user_id, is_api_key=False, scheduled_query=None, metadata={} ): @@ -107,13 +46,14 @@ def enqueue_query( if job_id: logging.info("[%s] Found existing job: %s", query_hash, job_id) - job = QueryTask(job_id=job_id) + job = Job.fetch(job_id) - if job.ready(): + status = job.get_status() + if status in [JobStatus.FINISHED, JobStatus.FAILED]: logging.info( "[%s] job found is ready (%s), removing lock", query_hash, - job.celery_status, + status, ) redis_connection.delete(_job_lock_id(query_hash, data_source.id)) job = None @@ -128,37 +68,23 @@ def enqueue_query( queue_name = data_source.queue_name scheduled_query_id = None - args = ( - query, - data_source.id, - metadata, - user_id, - scheduled_query_id, - is_api_key, - ) - argsrepr = json_dumps( - { - "org_id": data_source.org_id, - "data_source_id": data_source.id, - "enqueue_time": time.time(), - "scheduled": scheduled_query_id is not None, - "query_id": metadata.get("Query ID"), - "user_id": user_id, - } - ) - time_limit = settings.dynamic_settings.query_time_limit( scheduled_query, user_id, data_source.org_id ) + metadata["Queue"] = queue_name - result = execute_query.apply_async( - args=args, - argsrepr=argsrepr, - queue=queue_name, - soft_time_limit=time_limit, + queue = Queue(queue_name) + job = queue.enqueue( + execute_query, + query, + data_source.id, + metadata, + user_id=user_id, + scheduled_query_id=scheduled_query_id, + is_api_key=is_api_key, + job_timeout=time_limit, ) - job = QueryTask(async_result=result) logging.info("[%s] Created new job: %s", query_hash, job.id) pipe.set( _job_lock_id(query_hash, data_source.id), @@ -201,20 +127,11 @@ def _resolve_user(user_id, is_api_key, query_id): return None -# We could have created this as a celery.Task derived class, and act as the task itself. But this might result in weird -# issues as the task class created once per process, so decided to have a plain object instead. 
 class QueryExecutor(object):
     def __init__(
-        self,
-        task,
-        query,
-        data_source_id,
-        user_id,
-        is_api_key,
-        metadata,
-        scheduled_query,
+        self, query, data_source_id, user_id, is_api_key, metadata, scheduled_query
     ):
-        self.task = task
+        self.job = get_current_job()
         self.query = query
         self.data_source_id = data_source_id
         self.metadata = metadata
@@ -242,7 +159,7 @@ def run(self):
         try:
             data, error = query_runner.run_query(annotated_query, self.user)
         except Exception as e:
-            if isinstance(e, SoftTimeLimitExceeded):
+            if isinstance(e, JobTimeoutException):
                 error = TIMEOUT_MESSAGE
             else:
                 error = text_type(e)
@@ -253,7 +170,7 @@ def run(self):
         run_time = time.time() - started_at
 
         logger.info(
-            "task=execute_query query_hash=%s data_length=%s error=[%s]",
+            "job=execute_query query_hash=%s data_length=%s error=[%s]",
             self.query_hash,
             data and len(data),
             error,
@@ -301,37 +218,34 @@ def run(self):
         return result
 
     def _annotate_query(self, query_runner):
-        self.metadata["Task ID"] = self.task.request.id
+        self.metadata["Job ID"] = self.job.id
         self.metadata["Query Hash"] = self.query_hash
-        self.metadata["Queue"] = self.task.request.delivery_info["routing_key"]
         self.metadata["Scheduled"] = self.scheduled_query is not None
 
         return query_runner.annotate_query(self.query, self.metadata)
 
     def _log_progress(self, state):
         logger.info(
-            "task=execute_query state=%s query_hash=%s type=%s ds_id=%d "
-            "task_id=%s queue=%s query_id=%s username=%s",
+            "job=execute_query state=%s query_hash=%s type=%s ds_id=%d "
+            "job_id=%s queue=%s query_id=%s username=%s",
             state,
             self.query_hash,
             self.data_source.type,
             self.data_source.id,
-            self.task.request.id,
-            self.task.request.delivery_info["routing_key"],
+            self.job.id,
+            self.metadata.get("Queue", "unknown"),
             self.metadata.get("Query ID", "unknown"),
             self.metadata.get("Username", "unknown"),
         )
 
     def _load_data_source(self):
-        logger.info("task=execute_query state=load_ds ds_id=%d", self.data_source_id)
+        logger.info("job=execute_query state=load_ds ds_id=%d", self.data_source_id)
         return models.DataSource.query.get(self.data_source_id)
 
 
 # user_id is added last as a keyword argument for backward compatibility -- to support executing previously submitted
 # jobs before the upgrade to this version.
-@celery.task(name="redash.tasks.execute_query", bind=True, track_started=True)
 def execute_query(
-    self,
     query,
     data_source_id,
     metadata,
@@ -344,6 +258,10 @@ def execute_query(
     else:
         scheduled_query = None
 
-    return QueryExecutor(
-        self, query, data_source_id, user_id, is_api_key, metadata, scheduled_query
-    ).run()
+    try:
+        return QueryExecutor(
+            query, data_source_id, user_id, is_api_key, metadata, scheduled_query
+        ).run()
+    except QueryExecutionError as e:
+        models.db.session.rollback()
+        return e
diff --git a/redash/schedule.py b/redash/tasks/schedule.py
similarity index 100%
rename from redash/schedule.py
rename to redash/tasks/schedule.py
diff --git a/redash/tasks/worker.py b/redash/tasks/worker.py
new file mode 100644
index 0000000000..9f6f93558c
--- /dev/null
+++ b/redash/tasks/worker.py
@@ -0,0 +1,130 @@
+import errno
+import os
+import signal
+import time
+from rq import Worker as BaseWorker, Queue as BaseQueue, get_current_job
+from rq.utils import utcnow
+from rq.timeouts import UnixSignalDeathPenalty, HorseMonitorTimeoutException
+from rq.job import Job as BaseJob, JobStatus
+
+
+class CancellableJob(BaseJob):
+    def cancel(self, pipeline=None):
+        # TODO - add tests that verify that queued jobs are removed from queue and running jobs are actively cancelled
+        if self.is_started:
+            self.meta["cancelled"] = True
+            self.save_meta()
+
+        super().cancel(pipeline=pipeline)
+
+    @property
+    def is_cancelled(self):
+        return self.meta.get("cancelled", False)
+
+
+class CancellableQueue(BaseQueue):
+    job_class = CancellableJob
+
+
+class HardLimitingWorker(BaseWorker):
+    """
+    RQ's work horses enforce time limits by setting a timed alarm and stopping jobs
+    when they reach their time limits. However, the work horse may be entirely blocked
+    and may not respond to the alarm interrupt. Since respecting timeouts is critical
+    in Redash (if we don't respect them, workers may be stuck indefinitely and, as a
+    result, service may be denied for other queries), we enforce two time limits:
+    1. A soft time limit, enforced by the work horse
+    2. A hard time limit, enforced by the parent worker
+
+    The HardLimitingWorker class changes the monitoring behavior of the default
+    RQ Worker by checking if the work horse is still busy with the job, even after
+    it should have timed out (+ a grace period of 15s). If it is, it kills the work horse.
+    """
+
+    grace_period = 15
+    queue_class = CancellableQueue
+    job_class = CancellableJob
+
+    def stop_executing_job(self, job):
+        os.kill(self.horse_pid, signal.SIGINT)
+        self.log.warning("Job %s has been cancelled.", job.id)
+
+    def soft_limit_exceeded(self, job):
+        seconds_under_monitor = (utcnow() - self.monitor_started).seconds
+        return seconds_under_monitor > job.timeout + self.grace_period
+
+    def enforce_hard_limit(self, job):
+        self.log.warning(
+            "Job %s exceeded timeout of %ds (+%ds grace period) but work horse did not terminate it. "
+            "Killing the work horse.",
+            job.id,
+            job.timeout,
+            self.grace_period,
+        )
+        self.kill_horse()
+
+    def monitor_work_horse(self, job):
+        """The worker will monitor the work horse and make sure that it
+        either executes successfully or the status of the job is set to
+        failed.
+        """
+        self.monitor_started = utcnow()
+        while True:
+            try:
+                with UnixSignalDeathPenalty(
+                    self.job_monitoring_interval, HorseMonitorTimeoutException
+                ):
+                    retpid, ret_val = os.waitpid(self._horse_pid, 0)
+                break
+            except HorseMonitorTimeoutException:
+                # Horse has not exited yet and is still running.
+ # Send a heartbeat to keep the worker alive. + self.heartbeat(self.job_monitoring_interval + 5) + + job.refresh() + + if job.is_cancelled: + self.stop_executing_job(job) + + if self.soft_limit_exceeded(job): + self.enforce_hard_limit(job) + except OSError as e: + # In case we encountered an OSError due to EINTR (which is + # caused by a SIGINT or SIGTERM signal during + # os.waitpid()), we simply ignore it and enter the next + # iteration of the loop, waiting for the child to end. In + # any other case, this is some other unexpected OS error, + # which we don't want to catch, so we re-raise those ones. + if e.errno != errno.EINTR: + raise + # Send a heartbeat to keep the worker alive. + self.heartbeat() + + if ret_val == os.EX_OK: # The process exited normally. + return + job_status = job.get_status() + if job_status is None: # Job completed and its ttl has expired + return + if job_status not in [JobStatus.FINISHED, JobStatus.FAILED]: + + if not job.ended_at: + job.ended_at = utcnow() + + # Unhandled failure: move the job to the failed queue + self.log.warning( + ( + "Moving job to FailedJobRegistry " + "(work-horse terminated unexpectedly; waitpid returned {})" + ).format(ret_val) + ) + + self.handle_job_failure( + job, + exc_string="Work-horse process was terminated unexpectedly " + "(waitpid returned %s)" % ret_val, + ) + + +Job = CancellableJob +Queue = CancellableQueue +Worker = HardLimitingWorker diff --git a/tests/tasks/test_queries.py b/tests/tasks/test_queries.py index a6de6dc485..24153bb91a 100644 --- a/tests/tasks/test_queries.py +++ b/tests/tasks/test_queries.py @@ -1,118 +1,128 @@ from unittest import TestCase -from collections import namedtuple import uuid -import datetime -import mock +from mock import patch, Mock + +from rq import Connection from tests import BaseTestCase -from redash import redis_connection, models -from redash.utils import json_dumps, utcnow +from redash import redis_connection, rq_redis_connection, models +from redash.utils import json_dumps from redash.query_runner.pg import PostgreSQL from redash.tasks.queries.execution import ( QueryExecutionError, enqueue_query, execute_query, ) +from redash.tasks import Job + + +def fetch_job(*args, **kwargs): + if any(args): + job_id = args[0] if isinstance(args[0], str) else args[0].id + else: + job_id = create_job().id + result = Mock() + result.id = job_id -FakeResult = namedtuple("FakeResult", "id") + return result -def gen_hash(*args, **kwargs): - return FakeResult(uuid.uuid4().hex) +def create_job(*args, **kwargs): + return Job(connection=rq_redis_connection) +@patch("redash.tasks.queries.execution.Job.fetch", side_effect=fetch_job) +@patch("redash.tasks.queries.execution.Queue.enqueue", side_effect=create_job) class TestEnqueueTask(BaseTestCase): - def test_multiple_enqueue_of_same_query(self): + def test_multiple_enqueue_of_same_query(self, enqueue, _): query = self.factory.create_query() - execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - - enqueue_query( - query.query_text, - query.data_source, - query.user_id, - False, - query, - {"Username": "Arik", "Query ID": query.id}, - ) - enqueue_query( - query.query_text, - query.data_source, - query.user_id, - False, - query, - {"Username": "Arik", "Query ID": query.id}, - ) - enqueue_query( - query.query_text, - query.data_source, - query.user_id, - False, - query, - {"Username": "Arik", "Query ID": query.id}, - ) - self.assertEqual(1, execute_query.apply_async.call_count) + with Connection(rq_redis_connection): + enqueue_query( + 
query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) - @mock.patch("redash.settings.dynamic_settings.query_time_limit", return_value=60) - def test_limits_query_time(self, _): + self.assertEqual(1, enqueue.call_count) + + @patch("redash.settings.dynamic_settings.query_time_limit", return_value=60) + def test_limits_query_time(self, _, enqueue, __): query = self.factory.create_query() - execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - - enqueue_query( - query.query_text, - query.data_source, - query.user_id, - False, - query, - {"Username": "Arik", "Query ID": query.id}, - ) - _, kwargs = execute_query.apply_async.call_args - self.assertEqual(60, kwargs.get("soft_time_limit")) + with Connection(rq_redis_connection): + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) + + _, kwargs = enqueue.call_args + self.assertEqual(60, kwargs.get("job_timeout")) - def test_multiple_enqueue_of_different_query(self): + def test_multiple_enqueue_of_different_query(self, enqueue, _): query = self.factory.create_query() - execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - - enqueue_query( - query.query_text, - query.data_source, - query.user_id, - False, - None, - {"Username": "Arik", "Query ID": query.id}, - ) - enqueue_query( - query.query_text + "2", - query.data_source, - query.user_id, - False, - None, - {"Username": "Arik", "Query ID": query.id}, - ) - enqueue_query( - query.query_text + "3", - query.data_source, - query.user_id, - False, - None, - {"Username": "Arik", "Query ID": query.id}, - ) - self.assertEqual(3, execute_query.apply_async.call_count) + with Connection(rq_redis_connection): + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text + "2", + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text + "3", + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) + self.assertEqual(3, enqueue.call_count) + +@patch("redash.tasks.queries.execution.get_current_job", side_effect=fetch_job) class QueryExecutorTests(BaseTestCase): - def test_success(self): + def test_success(self, _): """ ``execute_query`` invokes the query runner and stores a query result. """ - cm = mock.patch( - "celery.app.task.Context.delivery_info", {"routing_key": "test"} - ) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: + with patch.object(PostgreSQL, "run_query") as qr: query_result_data = {"columns": [], "rows": []} qr.return_value = (json_dumps(query_result_data), None) result_id = execute_query("SELECT 1, 2", self.factory.data_source.id, {}) @@ -120,17 +130,14 @@ def test_success(self): result = models.QueryResult.query.get(result_id) self.assertEqual(result.data, query_result_data) - def test_success_scheduled(self): + def test_success_scheduled(self, _): """ Scheduled queries remember their latest results. 
""" - cm = mock.patch( - "celery.app.task.Context.delivery_info", {"routing_key": "test"} - ) q = self.factory.create_query( query_text="SELECT 1, 2", schedule={"interval": 300} ) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: + with patch.object(PostgreSQL, "run_query") as qr: qr.return_value = ([1, 2], None) result_id = execute_query( "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id @@ -140,87 +147,50 @@ def test_success_scheduled(self): result = models.QueryResult.query.get(result_id) self.assertEqual(q.latest_query_data, result) - def test_failure_scheduled(self): + def test_failure_scheduled(self, _): """ Scheduled queries that fail have their failure recorded. """ - cm = mock.patch( - "celery.app.task.Context.delivery_info", {"routing_key": "test"} - ) q = self.factory.create_query( query_text="SELECT 1, 2", schedule={"interval": 300} ) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: + with patch.object(PostgreSQL, "run_query") as qr: qr.side_effect = ValueError("broken") - with self.assertRaises(QueryExecutionError): - execute_query( - "SELECT 1, 2", - self.factory.data_source.id, - {}, - scheduled_query_id=q.id, - ) + + result = execute_query( + "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id + ) + self.assertTrue(isinstance(result, QueryExecutionError)) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 1) - with self.assertRaises(QueryExecutionError): - execute_query( - "SELECT 1, 2", - self.factory.data_source.id, - {}, - scheduled_query_id=q.id, - ) + + result = execute_query( + "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id + ) + self.assertTrue(isinstance(result, QueryExecutionError)) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 2) - def test_success_after_failure(self): + def test_success_after_failure(self, _): """ Query execution success resets the failure counter. 
""" - cm = mock.patch( - "celery.app.task.Context.delivery_info", {"routing_key": "test"} - ) q = self.factory.create_query( query_text="SELECT 1, 2", schedule={"interval": 300} ) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: + with patch.object(PostgreSQL, "run_query") as qr: qr.side_effect = ValueError("broken") - with self.assertRaises(QueryExecutionError): - execute_query( - "SELECT 1, 2", - self.factory.data_source.id, - {}, - scheduled_query_id=q.id, - ) + result = execute_query( + "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id + ) + self.assertTrue(isinstance(result, QueryExecutionError)) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 1) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: + with patch.object(PostgreSQL, "run_query") as qr: qr.return_value = ([1, 2], None) execute_query( "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 0) - - def test_doesnt_change_updated_at_timestamp(self): - cm = mock.patch("celery.app.task.Context.delivery_info", - {'routing_key': 'test'}) - - month_ago = utcnow() + datetime.timedelta(-30) - q = self.factory.create_query(query_text="SELECT 1, 2", schedule={"interval": 300}, updated_at=month_ago) - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: - qr.side_effect = ValueError("broken") - with self.assertRaises(QueryExecutionError): - execute_query("SELECT 1, 2", self.factory.data_source.id, {}, - scheduled_query_id=q.id) - q = models.Query.get_by_id(q.id) - self.assertEqual(q.schedule_failures, 1) - self.assertEqual(q.updated_at, month_ago) - - with cm, mock.patch.object(PostgreSQL, "run_query") as qr: - qr.return_value = ([1, 2], None) - execute_query("SELECT 1, 2", - self.factory.data_source.id, {}, - scheduled_query_id=q.id) - q = models.Query.get_by_id(q.id) - self.assertEqual(q.schedule_failures, 0) - self.assertEqual(q.updated_at, month_ago) diff --git a/tests/test_schedule.py b/tests/tasks/test_schedule.py similarity index 92% rename from tests/test_schedule.py rename to tests/tasks/test_schedule.py index 55af4e3dad..3e3329f62d 100644 --- a/tests/test_schedule.py +++ b/tests/tasks/test_schedule.py @@ -1,7 +1,7 @@ from unittest import TestCase from mock import patch, ANY -from redash.schedule import rq_scheduler, schedule_periodic_jobs +from redash.tasks.schedule import rq_scheduler, schedule_periodic_jobs class TestSchedule(TestCase): @@ -27,7 +27,7 @@ def foo(): pass schedule_periodic_jobs([{"func": foo, "interval": 60}]) - with patch("redash.schedule.rq_scheduler.schedule") as schedule: + with patch("redash.tasks.rq_scheduler.schedule") as schedule: schedule_periodic_jobs([{"func": foo, "interval": 60}]) schedule.assert_not_called() @@ -61,3 +61,4 @@ def bar(): self.assertEqual(len(jobs), 1) self.assertTrue(jobs[0].func_name.endswith("foo")) self.assertEqual(jobs[0].meta["interval"], 60) +