From 05bbe48439eb4093f09f8a690885c69fd8263b9a Mon Sep 17 00:00:00 2001 From: AndrewChubatiuk Date: Sat, 4 May 2024 14:37:17 +0300 Subject: [PATCH] pr comments --- poetry.lock | 26 ++--- pyproject.toml | 4 +- redash/app.py | 1 - redash/cli/groups.py | 7 +- redash/cli/users.py | 15 +-- redash/destinations/discord.py | 2 +- redash/handlers/alerts.py | 2 +- redash/handlers/organization.py | 33 +++--- redash/handlers/queries.py | 27 ++++- redash/handlers/users.py | 10 +- redash/handlers/visualizations.py | 6 +- redash/handlers/widgets.py | 2 +- redash/models/__init__.py | 139 +++++++++++++------------- redash/models/changes.py | 2 +- redash/models/organizations.py | 11 +- redash/models/users.py | 69 ++++++++----- redash/monitor.py | 10 +- redash/query_runner/databend.py | 10 +- redash/serializers/__init__.py | 14 +-- redash/tasks/alerts.py | 4 +- redash/tasks/queries/execution.py | 9 +- redash/tasks/queries/maintenance.py | 15 +-- tests/factories.py | 10 +- tests/handlers/test_alerts.py | 8 +- tests/handlers/test_dashboards.py | 2 +- tests/handlers/test_destinations.py | 2 +- tests/handlers/test_embed.py | 6 +- tests/handlers/test_visualizations.py | 3 +- tests/handlers/test_widgets.py | 6 +- tests/models/test_alerts.py | 8 +- tests/models/test_dashboards.py | 20 ++-- tests/models/test_data_sources.py | 3 +- tests/models/test_queries.py | 18 ++-- tests/tasks/test_alerts.py | 2 +- tests/test_authentication.py | 110 +++++++++----------- tests/test_cli.py | 31 +++--- tests/test_handlers.py | 3 +- tests/test_models.py | 10 +- 38 files changed, 338 insertions(+), 322 deletions(-) diff --git a/poetry.lock b/poetry.lock index 20f1b06d35..7bcb66a47f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -936,15 +936,16 @@ test-randomorder = ["pytest-randomly"] [[package]] name = "databend-driver" -version = "0.12.5" +version = "0.17.1" description = "Databend Driver Python Binding" optional = false -python-versions = ">=3.7" +python-versions = "<3.13,>=3.7" files = [ - {file = "databend_driver-0.12.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:91389252dd1a12e645fdb27a6d28203f5481b88b10187305ccc07d0e3e36f52c"}, - {file = "databend_driver-0.12.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:ad256ff5649c1d2fa45e5e541a552d5270a02b287555f627aacf4e288cf12d72"}, - {file = "databend_driver-0.12.5-cp37-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:c5479e95884a97487f0f916146a2ba5ac771ef28fb0862f81a7267c31d167812"}, - {file = "databend_driver-0.12.5-cp37-abi3-win_amd64.whl", hash = "sha256:da96dd8fa2bc67d8608921184923ffa485f7d84c503319552cc66c2e49f01f7e"}, + {file = "databend_driver-0.17.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:0e1618c37fe831f271dea7e4a61dbf1e942e9623fbeeb789afb03ad8fbb76e0d"}, + {file = "databend_driver-0.17.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:ee7842630267f882a77912b8cefa1fcdea57bd36ebc545e87fb741f05b1f7b84"}, + {file = "databend_driver-0.17.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3068248f3d83129375130ac932814e5a79a956e04681b2c853a37328fa34444"}, + {file = "databend_driver-0.17.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:a681ba8b0379272bb5015c94cd39554b8b6b79a5e20b42f998133b1d5981ec7b"}, + {file = "databend_driver-0.17.1-cp37-abi3-win_amd64.whl", hash = "sha256:74392c141ad2c15108086d7e46f877a9eb6974fd8f739aff59bfe7bf539cb1b6"}, ] [package.extras] @@ -954,21 +955,20 @@ test = ["behave"] [[package]] name = "databend-sqlalchemy" -version = "0.4.4" +version = "0.4.6" description = "Sqlalchemy adapter for Databend" optional = false python-versions = ">=3.7" files = [ - {file = "databend_sqlalchemy-0.4.4-py3-none-any.whl", hash = "sha256:190d85d64aa5207fd72bc9aa1720f465aa415e5aa836ea4805ca3b3dc62e75fd"}, - {file = "databend_sqlalchemy-0.4.4.tar.gz", hash = "sha256:114256b88f15588dc7ee4d9ef795e676c0ef9351718042ed181ee50158396531"}, + {file = "databend_sqlalchemy-0.4.6-py3-none-any.whl", hash = "sha256:39ca7d64ce1b2bce066ea1d15f2ef71a9f641588db5e09926cb14f7be4e5e519"}, ] [package.dependencies] -databend-driver = ">=0.12.1" -sqlalchemy = ">1.3.21" +databend-driver = ">=0.16.0" +sqlalchemy = ">=1.4" [package.extras] -dev = ["devtools (==0.7.0)", "mock (==4.0.3)", "pre-commit (==2.15.0)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "sqlalchemy-stubs (==0.4)"] +dev = ["devtools (==0.7.0)", "mock (==4.0.3)", "pre-commit (==2.15.0)", "pytest (==8.1.1)", "pytest-cov (==3.0.0)", "pytest-xdist (==3.5.0)", "sqlalchemy-stubs (==0.4)"] superset = ["apache-superset (>=1.4.1)"] [[package]] @@ -5255,4 +5255,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "1a38e5d879398123b962d610594c3696261367e010052666e6bfda98c9bab936" +content-hash = "2b98e4fa6b01454b78a76b46df8b12c6200b5310846c213f77e74450b6a1bb5f" diff --git a/pyproject.toml b/pyproject.toml index 57a7bcf6ee..4bbb824eda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,8 +97,8 @@ botocore = "1.31.8" cassandra-driver = "3.21.0" certifi = ">=2019.9.11" cmem-cmempy = "21.2.3" -databend-driver = "0.12.5" -databend-sqlalchemy = "0.4.4" +databend-driver = "0.17.1" +databend-sqlalchemy = "0.4.6" google-api-python-client = "1.7.11" gspread = "5.11.2" impyla = "0.16.0" diff --git a/redash/app.py b/redash/app.py index 0606dfbc30..c5aaaa6fea 100644 --- a/redash/app.py +++ b/redash/app.py @@ -28,7 +28,6 @@ def create_app(): from redash.metrics import request as request_metrics from redash.models import db, users from redash.utils import sentry - from redash.version_check import reset_new_version_status from . import ( limiter, diff --git a/redash/cli/groups.py b/redash/cli/groups.py index 06e4a7a805..c61c6a20ea 100644 --- a/redash/cli/groups.py +++ b/redash/cli/groups.py @@ -95,13 +95,12 @@ def extract_permissions_string(permissions): ) def list_command(organization=None): """List all groups""" + query_groups = select(models.Group) if organization: org = models.Organization.get_by_slug(organization) - qgroups = select(models.Group).where(models.Group.org == org) - else: - qgroups = select(models.Group) + query_groups = query_groups.where(models.Group.org == org) - groups = models.db.session.scalars(qgroups.order_by(models.Group.name)).all() + groups = models.db.session.scalars(query_groups.order_by(models.Group.name)).all() for i, group in enumerate(groups): if i > 0: print("-" * 20) diff --git a/redash/cli/users.py b/redash/cli/users.py index a0ae06cca4..24cd6fcef9 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -206,13 +206,9 @@ def delete(email, organization=None): if organization: org = models.Organization.get_by_slug(organization) deleted_users_query = deleted_users_query.where(models.User.org == org) - deleted_count = len( - models.db.session.scalars( - deleted_users_query.returning(models.User.id).execution_options(synchronize_session=False) - ).all() - ) + result = models.db.session.execute(deleted_users_query.execution_options(synchronize_session=False)) models.db.session.commit() - print("Deleted %d users." % deleted_count) + print("Deleted %d users." % result.rowcount) @manager.command() @@ -293,12 +289,11 @@ def invite(email, name, inviter_email, groups, is_admin=False, organization="def ) def list_command(organization=None): """List all users""" + query_users = select(models.User) if organization: org = models.Organization.get_by_slug(organization) - qusers = select(models.User).where(models.User.org == org) - else: - qusers = select(models.User) - users = models.db.session.scalars(qusers.order_by(models.User.name)).all() + query_users = query_users.where(models.User.org == org) + users = models.db.session.scalars(query_users.order_by(models.User.name)).all() for i, user in enumerate(users): if i > 0: diff --git a/redash/destinations/discord.py b/redash/destinations/discord.py index bd21e75c97..620b3e86a4 100644 --- a/redash/destinations/discord.py +++ b/redash/destinations/discord.py @@ -44,7 +44,7 @@ def notify(self, alert, query, user, new_state, app, host, metadata, options): ] if alert.custom_body: fields.append({"name": "Description", "value": alert.custom_body}) - if new_state == Alert.TRIGGERED_STATE: + if new_state == Alerts.TRIGGERED_STATE: if alert.options.get("custom_subject"): text = alert.options["custom_subject"] else: diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index a796aa789d..0bd67b039c 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -73,7 +73,7 @@ def post(self): alert = models.Alert( name=req["name"], - query_rel=query, + query=query, user=self.current_user, rearm=req.get("rearm"), options=req["options"], diff --git a/redash/handlers/organization.py b/redash/handlers/organization.py index e7d7aea7f7..a27bdadf44 100644 --- a/redash/handlers/organization.py +++ b/redash/handlers/organization.py @@ -1,6 +1,5 @@ from flask_login import current_user, login_required from sqlalchemy import distinct, func -from sqlalchemy.sql.expression import select from redash import models from redash.authentication import current_org @@ -12,34 +11,34 @@ @login_required def organization_status(org_slug=None): counters = { - "users": models.db.session.execute( - models.User.all(current_org, columns=[func.count(models.User.id)]) - ).first()[0], - "alerts": models.db.session.execute( - models.Alert.all( - group_ids=current_user.group_ids, columns=[func.count(models.Alert.id)], distinct=[] - ) - ).first()[0], - "data_sources": models.db.session.execute( + "users": models.db.session.scalar(models.User.all(current_org, columns=[func.count(models.User.id)])), + "alerts": models.db.session.scalar( + models.Alert.all(group_ids=current_user.group_ids, columns=[func.count(models.Alert.id)], distinct=[]) + ), + "data_sources": models.db.session.scalar( models.DataSource.all( current_org, group_ids=current_user.group_ids, columns=[func.count(models.DataSource.id)], ) - ).first()[0], - "queries": models.db.session.execute( + ), + "queries": models.db.session.scalar( models.Query.all( current_user.group_ids, user_id=current_user.id, include_drafts=True, columns=[func.count(distinct(models.Query.id))], ) - ).first()[0], - "dashboards": models.db.session.execute( - select(func.count(models.Dashboard.id)).where( - models.Dashboard.org == current_org, models.Dashboard.is_archived.is_(False) + ), + "dashboards": models.db.session.scalar( + models.Dashboard.all( + current_org, + [], + None, + columns=[func.count(distinct(models.Dashboard.id))], + distinct=[], ) - ).first()[0], + ), } return json_response(dict(object_counters=counters)) diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index a26b8a60bd..3afc86e515 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -150,7 +150,14 @@ def get(self): page = request.args.get("page", 1, type=int) per_page = request.args.get("per_page", 25, type=int) - results = paginate(ordered_results, page=page, per_page=per_page, serializer=QuerySerializer) + results = paginate( + ordered_results, + page=page, + per_page=per_page, + serializer=QuerySerializer, + with_stats=True, + with_last_modified_by=False, + ) if search_term: self.record_event({"action": "search", "object_type": "query", "term": search_term}) @@ -290,7 +297,14 @@ def get(self): page = request.args.get("page", 1, type=int) per_page = request.args.get("per_page", 25, type=int) - return paginate(ordered_results, page=page, per_page=per_page, serializer=QuerySerializer) + return paginate( + ordered_results, + page=page, + per_page=per_page, + serializer=QuerySerializer, + with_stats=True, + with_last_modified_by=False, + ) class QueryResource(BaseResource): @@ -483,7 +497,14 @@ def get(self): page = request.args.get("page", 1, type=int) per_page = request.args.get("per_page", 25, type=int) - results = paginate(ordered_favorites, page=page, per_page=per_page, serializer=QuerySerializer) + results = paginate( + ordered_favorites, + page=page, + per_page=per_page, + serializer=QuerySerializer, + with_stats=True, + with_last_modified_by=False, + ) self.record_event( { diff --git a/redash/handlers/users.py b/redash/handlers/users.py index d684532011..42373829ca 100644 --- a/redash/handlers/users.py +++ b/redash/handlers/users.py @@ -3,7 +3,7 @@ from flask_login import current_user, login_user from flask_restful import abort from funcy import partial, project -from sqlalchemy import asc, desc +from sqlalchemy import desc from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound @@ -32,13 +32,13 @@ # Ordering map for relationships order_map = { - "name": [asc(models.User.name)], + "name": [models.User.name], "-name": [desc(models.User.name)], - "active_at": [asc(models.User.active_at)], + "active_at": [models.User.active_at], "-active_at": [desc(models.User.active_at)], - "created_at": [asc(models.User.created_at)], + "created_at": [models.User.created_at], "-created_at": [desc(models.User.created_at)], - "groups": [asc(models.User.group_ids)], + "groups": [models.User.group_ids], "-groups": [desc(models.User.group_ids)], } diff --git a/redash/handlers/visualizations.py b/redash/handlers/visualizations.py index f29a1fb36c..e414211683 100644 --- a/redash/handlers/visualizations.py +++ b/redash/handlers/visualizations.py @@ -17,7 +17,7 @@ def post(self): query = get_object_or_404(models.Query.get_by_id_and_org, kwargs.pop("query_id"), self.current_org) require_object_modify_permission(query, self.current_user) - kwargs["query_rel"] = query + kwargs["query"] = query vis = models.Visualization(**kwargs) models.db.session.add(vis) @@ -29,7 +29,7 @@ class VisualizationResource(BaseResource): @require_permission("edit_query") def post(self, visualization_id): vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) - require_object_modify_permission(vis.query_rel, self.current_user) + require_object_modify_permission(vis.query, self.current_user) kwargs = request.get_json(force=True) @@ -44,7 +44,7 @@ def post(self, visualization_id): @require_permission("edit_query") def delete(self, visualization_id): vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) - require_object_modify_permission(vis.query_rel, self.current_user) + require_object_modify_permission(vis.query, self.current_user) self.record_event( { "action": "delete", diff --git a/redash/handlers/widgets.py b/redash/handlers/widgets.py index 051b6e386c..5480f93516 100644 --- a/redash/handlers/widgets.py +++ b/redash/handlers/widgets.py @@ -34,7 +34,7 @@ def post(self): visualization_id = widget_properties.pop("visualization_id") if visualization_id: visualization = models.Visualization.get_by_id_and_org(visualization_id, self.current_org) - require_access(visualization.query_rel, self.current_user, view_only) + require_access(visualization.query, self.current_user, view_only) else: visualization = None diff --git a/redash/models/__init__.py b/redash/models/__init__.py index fca9707cca..8253cfe7a2 100644 --- a/redash/models/__init__.py +++ b/redash/models/__init__.py @@ -13,7 +13,6 @@ from sqlalchemy.event import listens_for from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( - backref, contains_eager, joinedload, load_only, @@ -117,10 +116,16 @@ def get(self, query_id): class DataSource(BelongsToOrgMixin, db.Model): id = primary_key("DataSource") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, backref="data_sources") + org = db.relationship("Organization", back_populates="data_sources", uselist=False) name = Column(db.String(255)) type = Column(db.String(255)) + queries = db.relationship("Query", back_populates="data_source", lazy="noload") + query_results = db.relationship( + "QueryResult", + back_populates="data_source", + lazy="noload", + ) options = Column( "encrypted_options", ConfigurationContainer.as_mutable( @@ -318,9 +323,9 @@ class DataSourceGroup(db.Model): # XXX drop id, use datasource/group as PK id = primary_key("DataSourceGroup") data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id")) - data_source = db.relationship(DataSource, back_populates="data_source_groups") + data_source = db.relationship("DataSource", back_populates="data_source_groups", uselist=False) group_id = Column(key_type("Group"), db.ForeignKey("groups.id")) - group = db.relationship(Group, back_populates="data_sources") + group = db.relationship("Group", back_populates="data_sources", uselist=False) view_only = Column(db.Boolean, default=False) __tablename__ = "data_source_groups" @@ -331,9 +336,10 @@ class DataSourceGroup(db.Model): class QueryResult(db.Model, BelongsToOrgMixin): id = primary_key("QueryResult") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization) + org = db.relationship("Organization", back_populates="query_results", uselist=False) data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id")) - data_source = db.relationship(DataSource, backref=backref("query_results")) + data_source = db.relationship("DataSource", back_populates="query_results", uselist=False) + queries = db.relationship("Query", back_populates="latest_query_data", lazy="noload") query_hash = Column(db.String(32), index=True) query_text = Column("query", db.Text) data = Column(JSONText, nullable=True) @@ -387,6 +393,13 @@ def get_latest(cls, data_source, query, max_age=0): @classmethod def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): + queries = db.session.scalars( + select(Query).where( + Query.query_hash == query_hash, + Query.data_source == data_source, + Query.is_archived.is_(False), + ) + ).all() query_result = cls( org_id=org, query_hash=query_hash, @@ -395,6 +408,7 @@ def store_result(cls, org, data_source, query_hash, query, data, run_time, retri data_source_id=data_source.id, retrieved_at=retrieved_at, data=data, + queries=queries, ) db.session.add(query_result) @@ -463,20 +477,22 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = primary_key("Query") version = Column(db.Integer, default=1) org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, backref="queries") + org = db.relationship("Organization", back_populates="queries", uselist=False) data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id"), nullable=True) - data_source = db.relationship(DataSource, backref="queries") + data_source = db.relationship("DataSource", back_populates="queries", uselist=False) latest_query_data_id = Column(key_type("QueryResult"), db.ForeignKey("query_results.id"), nullable=True) - latest_query_data = db.relationship(QueryResult) + latest_query_data = db.relationship("QueryResult", back_populates="queries", uselist=False) name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) query_text = Column("query", db.Text) query_hash = Column(db.String(32)) api_key = Column(db.String(40), default=lambda: generate_token(40)) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User, foreign_keys=[user_id]) + user = db.relationship("User", foreign_keys="[Query.user_id]", uselist=False) last_modified_by_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True) - last_modified_by = db.relationship(User, backref="modified_queries", foreign_keys=[last_modified_by_id]) + last_modified_by = db.relationship( + "User", back_populates="modified_queries", foreign_keys="[Query.last_modified_by_id]", uselist=False + ) is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) schedule = Column(MutableDict.as_mutable(JSONB), nullable=True) @@ -484,6 +500,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): schedule_failures = Column(db.Integer, default=0) visualizations = db.relationship("Visualization", cascade="all, delete-orphan") options = Column(MutableDict.as_mutable(JSONB), default={}) + alerts = db.relationship("Alert", back_populates="query", lazy="noload") search_vector = Column( TSVectorType( "id", @@ -527,7 +544,7 @@ def create(cls, **kwargs): query = cls(**kwargs) db.session.add( Visualization( - query_rel=query, + query=query, name="Table", description="", type="TABLE", @@ -737,26 +754,6 @@ def all_groups_for_query_ids(cls, query_ids): return db.session.execute(text(query), {"ids": tuple(query_ids)}).all() - @classmethod - def update_latest_result(cls, query_result): - # TODO: Investigate how big an impact this select-before-update makes. - queries = db.session.scalars( - select(Query).where( - Query.query_hash == query_result.query_hash, - Query.data_source == query_result.data_source, - Query.is_archived.is_(False), - ) - ).all() - for q in queries: - q.latest_query_data = query_result - q.skip_updated_at = True - db.session.add(q) - query_ids = [q.id for q in queries] - - logging.info("Updated %s queries with result (%s).", len(query_ids), query_result.query_hash) - - return query_ids - def fork(self, user): forked_list = [ "org", @@ -775,7 +772,7 @@ def fork(self, user): for v in sorted(self.visualizations, key=lambda v: v.id): forked_v = v.copy() - forked_v["query_rel"] = forked_query + forked_v["query"] = forked_query fv = Visualization(**forked_v) # it will magically add it to `forked_query.visualizations` db.session.add(fv) @@ -868,19 +865,14 @@ class Favorite(TimestampMixin, db.Model): object = generic_relationship(object_type, object_id) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User, backref="favorites") + user = db.relationship("User", back_populates="favorites", uselist=False) __tablename__ = "favorites" __table_args__ = (UniqueConstraint("object_type", "object_id", "user_id", name="unique_favorite"),) @classmethod def is_favorite(cls, user, object): - return ( - db.session.execute( - select(func.count(cls.id)).where(cls.object == object, cls.user_id == user) - ).first()[0] - > 0 - ) + return db.session.scalar(select(func.count(cls.id)).where(cls.object == object, cls.user_id == user)) > 0 @classmethod def are_favorites(cls, user, objects): @@ -947,9 +939,9 @@ class Alert(TimestampMixin, BelongsToOrgMixin, db.Model): id = primary_key("Alert") name = Column(db.String(255)) query_id = Column(key_type("Query"), db.ForeignKey("queries.id")) - query_rel = db.relationship(Query, backref=backref("alerts", cascade="all")) + query = db.relationship("Query", back_populates="alerts", cascade="all", uselist=False) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User, backref="alerts") + user = db.relationship("User", back_populates="alerts", uselist=False) options = Column(MutableDict.as_mutable(JSONB), nullable=True) state = Column(db.String(255), default=Alerts.UNKNOWN_STATE) subscriptions = db.relationship("AlertSubscription", cascade="all, delete-orphan") @@ -977,7 +969,7 @@ def get_by_id_and_org(cls, object_id, org): return super(Alert, cls).get_by_id_and_org(object_id, org, Query) def evaluate(self): - data = self.query_rel.latest_query_data.data + data = self.query.latest_query_data.data if data["rows"] and self.options["column"] in data["rows"][0]: op = OPERATORS.get(self.options["op"], lambda v, t: False) @@ -998,8 +990,8 @@ def render_template(self, template): if template is None: return "" - data = self.query_rel.latest_query_data.data - host = base_url(self.query_rel.org) + data = self.query.latest_query_data.data + host = base_url(self.query.org) col_name = self.options["column"] if data["rows"] and col_name in data["rows"][0]: @@ -1017,8 +1009,8 @@ def render_template(self, template): "ALERT_STATUS": self.state.upper(), "ALERT_CONDITION": self.options["op"], "ALERT_THRESHOLD": self.options["value"], - "QUERY_NAME": self.query_rel.name, - "QUERY_URL": "{host}/queries/{query_id}".format(host=host, query_id=self.query_rel.id), + "QUERY_NAME": self.query.name, + "QUERY_URL": "{host}/queries/{query_id}".format(host=host, query_id=self.query.id), "QUERY_RESULT_VALUE": result_value, "QUERY_RESULT_ROWS": data["rows"], "QUERY_RESULT_COLS": data["columns"], @@ -1038,7 +1030,7 @@ def custom_subject(self): @property def groups(self): - return self.query_rel.groups + return self.query.groups @property def muted(self): @@ -1060,17 +1052,17 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model id = primary_key("Dashboard") version = Column(db.Integer) org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, backref="dashboards") + org = db.relationship("Organization", back_populates="dashboards", uselist=False) slug = Column(db.String(140), index=True, default=generate_slug) name = Column(db.String(100)) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User) + user = db.relationship("User", back_populates="dashboards", uselist=False) # layout is no longer used, but kept so we know how to render old dashboards. layout = Column(MutableList.as_mutable(JSONB), default=[]) dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) - widgets = db.relationship("Widget", backref="dashboard", lazy="dynamic") + widgets = db.relationship("Widget", back_populates="dashboard") tags = Column("tags", MutableList.as_mutable(ARRAY(db.Unicode)), nullable=True) options = Column(MutableDict.as_mutable(JSONB), default={}) @@ -1086,10 +1078,19 @@ def name_as_slug(self): @classmethod def all(cls, org, group_ids, user_id, columns=None, distinct=None): + conditions = [ + cls.is_archived.is_(False), + cls.org == org, + ] if columns is None: columns = [cls, User.id, User.name, User.details, User.email] if distinct is None: distinct = [func.lower(cls.name), cls.created_at, cls.slug] + if len(group_ids) > 0 or user_id is not None: + conditions = conditions + [ + or_(DataSourceGroup.group_id.in_(group_ids), cls.user_id == user_id), + or_(cls.user_id == user_id, cls.is_draft.is_(False)), + ] query = ( select(*columns) .join(User) @@ -1098,12 +1099,7 @@ def all(cls, org, group_ids, user_id, columns=None, distinct=None): .outerjoin(Visualization) .outerjoin(Query) .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) - .where( - cls.is_archived.is_(False), - cls.org == org, - or_(DataSourceGroup.group_id.in_(group_ids), cls.user_id == user_id), - or_(cls.user_id == user_id, cls.is_draft.is_(False)), - ) + .where(*conditions) ) return query @@ -1172,11 +1168,11 @@ class Visualization(TimestampMixin, BelongsToOrgMixin, db.Model): id = primary_key("Visualization") type = Column(db.String(100)) query_id = Column(key_type("Query"), db.ForeignKey("queries.id")) - # query_rel and not query, because db.Model already has query defined. - query_rel = db.relationship(Query, back_populates="visualizations") + query = db.relationship("Query", back_populates="visualizations", uselist=False) name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) options = Column(MutableDict.as_mutable(JSONB), nullable=True) + widgets = db.relationship("Widget", back_populates="visualization", cascade="all") __tablename__ = "visualizations" @@ -1200,11 +1196,12 @@ def copy(self): class Widget(TimestampMixin, BelongsToOrgMixin, db.Model): id = primary_key("Widget") visualization_id = Column(key_type("Visualization"), db.ForeignKey("visualizations.id"), nullable=True) - visualization = db.relationship(Visualization, backref=backref("widgets", cascade="delete")) + visualization = db.relationship("Visualization", back_populates="widgets", cascade="delete", uselist=False) text = Column(db.Text, nullable=True) width = Column(db.Integer) options = Column(MutableDict.as_mutable(JSONB), default={}) dashboard_id = Column(key_type("Dashboard"), db.ForeignKey("dashboards.id"), index=True) + dashboard = db.relationship("Dashboard", back_populates="widgets", uselist=False) __tablename__ = "widgets" @@ -1229,9 +1226,9 @@ def copy(self, dashboard_id): class Event(db.Model): id = primary_key("Event") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, back_populates="events") + org = db.relationship(Organization, back_populates="events", uselist=False) user_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True) - user = db.relationship(User, backref="events") + user = db.relationship(User, back_populates="events", uselist=False) action = Column(db.String(255)) object_type = Column(db.String(255)) object_id = Column(db.String(255), nullable=True) @@ -1286,13 +1283,13 @@ def record(cls, event): class ApiKey(TimestampMixin, GFKBase, db.Model): id = primary_key("ApiKey") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization) + org = db.relationship("Organization", uselist=False) api_key = Column(db.String(255), index=True, default=lambda: generate_token(40)) active = Column(db.Boolean, default=True) # 'object' provided by GFKBase object_id = Column(key_type("ApiKey")) created_by_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True) - created_by = db.relationship(User) + created_by = db.relationship("User", uselist=False) __tablename__ = "api_keys" __table_args__ = (db.Index("api_keys_object_type_object_id", "object_type", "object_id"),) @@ -1320,9 +1317,9 @@ def create_for_object(cls, object, user): class NotificationDestination(BelongsToOrgMixin, db.Model): id = primary_key("NotificationDestination") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, backref="notification_destinations") + org = db.relationship("Organization", back_populates="notification_destinations", uselist=False) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User, backref="notification_destinations") + user = db.relationship("User", back_populates="notification_destinations", uselist=False) name = Column(db.String(255)) type = Column(db.String(255)) options = Column( @@ -1373,13 +1370,13 @@ def notify(self, alert, query, user, new_state, app, host, metadata): class AlertSubscription(TimestampMixin, db.Model): id = primary_key("AlertSubscription") user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User) + user = db.relationship("User", back_populates="alert_subscriptions", uselist=False) destination_id = Column( key_type("NotificationDestination"), db.ForeignKey("notification_destinations.id"), nullable=True ) - destination = db.relationship(NotificationDestination) + destination = db.relationship("NotificationDestination", uselist=False) alert_id = Column(key_type("Alert"), db.ForeignKey("alerts.id")) - alert = db.relationship(Alert, back_populates="subscriptions") + alert = db.relationship("Alert", back_populates="subscriptions", uselist=False) __tablename__ = "alert_subscriptions" __table_args__ = ( @@ -1419,11 +1416,11 @@ def notify(self, alert, query, user, new_state, app, host, metadata): class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): id = primary_key("QuerySnippet") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship(Organization, backref="query_snippets") + org = db.relationship("Organization", back_populates="query_snippets", uselist=False) trigger = Column(db.String(255), unique=True) description = Column(db.Text) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship(User, backref="query_snippets") + user = db.relationship("User", back_populates="query_snippets", uselist=False) snippet = Column(db.Text) __tablename__ = "query_snippets" diff --git a/redash/models/changes.py b/redash/models/changes.py index c3ab7bcd6f..1c27363d51 100644 --- a/redash/models/changes.py +++ b/redash/models/changes.py @@ -13,7 +13,7 @@ class Change(GFKBase, db.Model): object_id = Column(key_type("Change")) object_version = Column(db.Integer, default=0) user_id = Column(key_type("User"), db.ForeignKey("users.id")) - user = db.relationship("User", backref="changes") + user = db.relationship("User", back_populates="changes") change = Column(JSONB) created_at = Column(db.DateTime(True), default=db.func.now()) diff --git a/redash/models/organizations.py b/redash/models/organizations.py index f9804632b2..7ecdae42ea 100644 --- a/redash/models/organizations.py +++ b/redash/models/organizations.py @@ -19,8 +19,15 @@ class Organization(TimestampMixin, db.Model): name = Column(db.String(255)) slug = Column(db.String(255), unique=True) settings = Column(MutableDict.as_mutable(JSONB), default={}) - groups = db.relationship("Group", lazy="dynamic") - events = db.relationship("Event", lazy="dynamic", order_by="desc(Event.created_at)") + queries = db.relationship("Query", back_populates="org") + groups = db.relationship("Group", back_populates="org") + events = db.relationship("Event", lazy="noload", order_by="desc(Event.created_at)") + notification_destinations = db.relationship("NotificationDestination", back_populates="org", lazy="noload") + query_snippets = db.relationship("QuerySnippet", back_populates="org", lazy="noload") + query_results = db.relationship("QueryResult", back_populates="org", lazy="noload") + data_sources = db.relationship("DataSource", back_populates="org", lazy="noload") + users = db.relationship("User", back_populates="org") + dashboards = db.relationship("Dashboard", back_populates="org") __tablename__ = "organizations" diff --git a/redash/models/users.py b/redash/models/users.py index 8b1b069e35..a647c5344a 100644 --- a/redash/models/users.py +++ b/redash/models/users.py @@ -8,6 +8,7 @@ from flask import current_app, request_started, url_for from flask_login import AnonymousUserMixin, UserMixin, current_user from passlib.apps import custom_app_context as pwd_context +from sqlalchemy import func from sqlalchemy.dialects.postgresql import ARRAY, JSONB from sqlalchemy.sql.expression import delete, select from sqlalchemy_utils import EmailType @@ -78,10 +79,27 @@ def has_permissions(self, permissions): class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): id = primary_key("User") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship("Organization", backref=db.backref("users", lazy="dynamic"), cascade_backrefs=False) + org = db.relationship("Organization", back_populates="users", uselist=False) name = Column(db.String(320)) email = Column(EmailType) password_hash = Column(db.String(128), nullable=True) + events = db.relationship("Event", back_populates="user", lazy="noload") + notification_destinations = db.relationship("NotificationDestination", back_populates="user", lazy="noload") + query_snippets = db.relationship("QuerySnippet", back_populates="user", lazy="noload") + favorites = db.relationship("Favorite", back_populates="user", lazy="noload") + alerts = db.relationship("Alert", back_populates="user", lazy="noload") + dashboards = db.relationship("Dashboard", back_populates="user", lazy="noload") + alert_subscriptions = db.relationship("AlertSubscription", back_populates="user", lazy="noload") + changes = db.relationship("Change", back_populates="user", lazy="noload") + modified_queries = db.relationship( + "Query", back_populates="last_modified_by", foreign_keys="[Query.last_modified_by_id]", lazy="noload" + ) + grantor = db.relationship( + "AccessPermission", back_populates="grantor", foreign_keys="[AccessPermission.grantor_id]", lazy="noload" + ) + grantee = db.relationship( + "AccessPermission", back_populates="grantee", foreign_keys="[AccessPermission.grantee_id]", lazy="noload" + ) group_ids = Column( "groups", MutableList.as_mutable(ARRAY(key_type("Group"))), @@ -209,10 +227,10 @@ def search(cls, base_query, term): @classmethod def pending(cls, base_query, pending): - clause = cls.is_invitation_pending.isnot(True) if pending: - clause = cls.is_invitation_pending.is_(True) - return base_query.where(clause) + return base_query.where(cls.is_invitation_pending.is_(True)) + else: + return base_query.where(cls.is_invitation_pending.isnot(True)) @classmethod def find_by_email(cls, email): @@ -266,7 +284,7 @@ class Group(db.Model, BelongsToOrgMixin): id = primary_key("Group") data_sources = db.relationship("DataSourceGroup", back_populates="group", cascade="all") org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) - org = db.relationship("Organization", back_populates="groups") + org = db.relationship("Organization", back_populates="groups", uselist=False) type = Column(db.String(255), default=REGULAR_GROUP) name = Column(db.String(100)) permissions = Column(ARRAY(db.String(255)), default=DEFAULT_PERMISSIONS) @@ -305,9 +323,13 @@ class AccessPermission(GFKBase, db.Model): # 'object' defined in GFKBase access_type = Column(db.String(255)) grantor_id = Column(key_type("User"), db.ForeignKey("users.id")) - grantor = db.relationship(User, backref="grantor", foreign_keys=[grantor_id]) + grantor = db.relationship( + "User", back_populates="grantor", foreign_keys="[AccessPermission.grantor_id]", uselist=False + ) grantee_id = Column(key_type("User"), db.ForeignKey("users.id")) - grantee = db.relationship(User, backref="grantee", foreign_keys=[grantee_id]) + grantee = db.relationship( + "User", back_populates="grantee", foreign_keys="[AccessPermission.grantee_id]", uselist=False + ) __tablename__ = "access_permissions" @@ -337,15 +359,9 @@ def grant(cls, obj, access_type, grantee, grantor): @classmethod def revoke(cls, obj, grantee, access_type=None): - q = delete(cls).where(cls.object_id == obj.id, cls.object_type == obj.__tablename__).returning(cls.id) - - if access_type: - q = q.where(AccessPermission.access_type == access_type) - - if grantee: - q = q.where(AccessPermission.grantee == grantee) - - return len(db.session.scalars(q).all()) + conditions = AccessPermission._query_condition(obj, access_type, grantee, None) + q = delete(cls).where(*conditions) + return db.session.execute(q).rowcount @classmethod def find(cls, obj, access_type=None, grantee=None, grantor=None): @@ -353,21 +369,24 @@ def find(cls, obj, access_type=None, grantee=None, grantor=None): @classmethod def exists(cls, obj, access_type, grantee): - return len(cls.find(obj, access_type, grantee)) > 0 + conditions = AccessPermission._query_condition(obj, access_type, grantee, None) + return db.session.scalar(select(func.count(cls.id)).where(*conditions)) > 0 @classmethod - def _query(cls, obj, access_type=None, grantee=None, grantor=None): - q = select(cls).where(cls.object_id == obj.id, cls.object_type == obj.__tablename__) - + def _query_condition(cls, obj, access_type=None, grantee=None, grantor=None): + conditions = [cls.object_id == obj.id, cls.object_type == obj.__tablename__] if access_type: - q = q.where(AccessPermission.access_type == access_type) - + conditions.append(AccessPermission.access_type == access_type) if grantee: - q = q.where(AccessPermission.grantee == grantee) - + conditions.append(AccessPermission.grantee == grantee) if grantor: - q = q.where(AccessPermission.grantor == grantor) + conditions.append(AccessPermission.grantor == grantor) + return conditions + @classmethod + def _query(cls, obj, access_type=None, grantee=None, grantor=None): + conditions = AccessPermission._query_condition(obj, access_type, grantee, grantor) + q = select(cls).where(*conditions) return db.session.scalars(q).all() def to_dict(self): diff --git a/redash/monitor.py b/redash/monitor.py index b509f17fcb..52146c241c 100644 --- a/redash/monitor.py +++ b/redash/monitor.py @@ -19,7 +19,7 @@ def get_redis_status(): def count(cls): - return db.session.execute(select(func.count(cls.id))).first()[0] + return db.session.scalar(select(func.count(cls.id))) def get_object_counts(): @@ -27,11 +27,9 @@ def get_object_counts(): status["queries_count"] = count(Query) if settings.FEATURE_SHOW_QUERY_RESULTS_COUNT: status["query_results_count"] = count(QueryResult) - status["unused_query_results_count"] = db.session.execute( - QueryResult.unused( - columns=[func.count(QueryResult.id)], days=settings.QUERY_RESULTS_CLEANUP_MAX_AGE - ) - ).first()[0] + status["unused_query_results_count"] = db.session.scalar( + QueryResult.unused(columns=[func.count(QueryResult.id)], days=settings.QUERY_RESULTS_CLEANUP_MAX_AGE) + ) status["dashboards_count"] = count(Dashboard) status["widgets_count"] = count(Widget) return status diff --git a/redash/query_runner/databend.py b/redash/query_runner/databend.py index c22a56af72..41c7bfc87a 100644 --- a/redash/query_runner/databend.py +++ b/redash/query_runner/databend.py @@ -1,7 +1,7 @@ try: import re - from databend_sqlalchemy import connector + from sqlalchemy import create_engine, text enabled = True except ImportError: @@ -73,12 +73,12 @@ def run_query(self, query, user): username = self.configuration.get("username") or "root" password = self.configuration.get("password") or "" database = self.configuration.get("database") or "default" - sslmode = self.configuration.get("sslmode") or False - connection = connector.connect(f"databend://{username}:{password}@{host}:{port}/{database}?sslmode={sslmode}") - cursor = connection.cursor() + sslmode = self.configuration.get("sslmode") or "disable" + engine = create_engine(f"databend://{username}:{password}@{host}:{port}/{database}?sslmode={sslmode}") + connection = engine.connect() try: - cursor.execute(query) + cursor = connection.execute(text(query)).cursor columns = self.fetch_columns([(i[0], self._define_column_type(i[1])) for i in cursor.description]) rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor] diff --git a/redash/serializers/__init__.py b/redash/serializers/__init__.py index ddaab00da7..1ffa7a85ff 100644 --- a/redash/serializers/__init__.py +++ b/redash/serializers/__init__.py @@ -41,10 +41,10 @@ def public_widget(widget): "updated_at": v.updated_at, "created_at": v.created_at, "query": { - "id": v.query_rel.id, - "name": v.query_rel.name, - "description": v.query_rel.description, - "options": v.query_rel.options, + "id": v.query.id, + "name": v.query.name, + "description": v.query.description, + "options": v.query.options, }, } @@ -155,7 +155,7 @@ def serialize_visualization(object, with_query=True): } if with_query: - d["query"] = serialize_query(object.query_rel) + d["query"] = serialize_query(object.query) return d @@ -190,7 +190,7 @@ def serialize_alert(alert, full=True): } if full: - d["query"] = serialize_query(alert.query_rel) + d["query"] = serialize_query(alert.query) d["user"] = alert.user.to_dict() else: d["query_id"] = alert.query_id @@ -208,7 +208,7 @@ def serialize_dashboard(obj, with_widgets=False, user=None, with_favorite_state= for w in obj.widgets: if w.visualization_id is None: widgets.append(serialize_widget(w)) - elif user and has_access(w.visualization.query_rel, user, view_only): + elif user and has_access(w.visualization.query, user, view_only): widgets.append(serialize_widget(w)) else: widget = project( diff --git a/redash/tasks/alerts.py b/redash/tasks/alerts.py index 7871afbe64..b6f96aa3d5 100644 --- a/redash/tasks/alerts.py +++ b/redash/tasks/alerts.py @@ -10,10 +10,10 @@ def notify_subscriptions(alert, new_state, metadata): - host = utils.base_url(alert.query_rel.org) + host = utils.base_url(alert.query.org) for subscription in alert.subscriptions: try: - subscription.notify(alert, alert.query_rel, subscription.user, new_state, current_app, host, metadata) + subscription.notify(alert, alert.query, subscription.user, new_state, current_app, host, metadata) except Exception: logger.exception("Error with processing destination") diff --git a/redash/tasks/queries/execution.py b/redash/tasks/queries/execution.py index 53db37fcaf..2fcb7b113d 100644 --- a/redash/tasks/queries/execution.py +++ b/redash/tasks/queries/execution.py @@ -204,10 +204,12 @@ def run(self): if error is not None and data is None: result = QueryExecutionError(error) if self.is_scheduled_query: + self.query_model = models.db.session.merge(self.query_model, load=False) track_failure(self.query_model, error) raise result else: if self.query_model and self.query_model.schedule_failures > 0: + self.query_model = models.db.session.merge(self.query_model, load=False) self.query_model.schedule_failures = 0 self.query_model.skip_updated_at = True models.db.session.add(self.query_model) @@ -221,14 +223,11 @@ def run(self): run_time, utcnow(), ) - models.db.session.commit() - - updated_query_ids = models.Query.update_latest_result(query_result) models.db.session.commit() # make sure that alert sees the latest query result self._log_progress("checking_alerts") - for query_id in updated_query_ids: - check_alerts_for_query.delay(query_id, self.metadata) + for q in query_result.queries: + check_alerts_for_query.delay(q.id, self.metadata) self._log_progress("finished") result = query_result.id diff --git a/redash/tasks/queries/maintenance.py b/redash/tasks/queries/maintenance.py index bc2eef2426..192ae48c94 100644 --- a/redash/tasks/queries/maintenance.py +++ b/redash/tasks/queries/maintenance.py @@ -132,16 +132,11 @@ def cleanup_query_results(): ) unused_query_results = models.QueryResult.unused(days=settings.QUERY_RESULTS_CLEANUP_MAX_AGE) - deleted_count = len( - models.db.session.scalars( - delete(models.QueryResult) - .where( - models.QueryResult.id.in_(unused_query_results.limit(settings.QUERY_RESULTS_CLEANUP_COUNT).subquery()) - ) - .execution_options(synchronize_session=False) - .returning(models.QueryResult.id) - ).all() - ) + deleted_count = models.db.session.execute( + delete(models.QueryResult) + .where(models.QueryResult.id.in_(unused_query_results.limit(settings.QUERY_RESULTS_CLEANUP_COUNT).subquery())) + .execution_options(synchronize_session=False) + ).rowcount models.db.session.commit() logger.info("Deleted %d unused query results.", deleted_count) diff --git a/tests/factories.py b/tests/factories.py index 676d998d08..fbdf8dd5c7 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -119,7 +119,7 @@ def __call__(self): alert_factory = ModelFactory( redash.models.Alert, name=Sequence("Alert {}"), - query_rel=query_factory.create, + query=query_factory.create, user=user_factory.create, options={}, ) @@ -138,7 +138,7 @@ def __call__(self): visualization_factory = ModelFactory( redash.models.Visualization, type="CHART", - query_rel=query_factory.create, + query=query_factory.create, name="Chart", description="", options={}, @@ -239,7 +239,7 @@ def create_group(self, **kwargs): return group_factory.create(**args) def create_alert(self, **kwargs): - args = {"user": self.user, "query_rel": self.create_query()} + args = {"user": self.user, "query": self.create_query()} args.update(kwargs) return alert_factory.create(**args) @@ -295,12 +295,12 @@ def create_query_result(self, **kwargs): return query_result_factory.create(**args) def create_visualization(self, **kwargs): - args = {"query_rel": self.create_query()} + args = {"query": self.create_query()} args.update(kwargs) return visualization_factory.create(**args) def create_visualization_with_params(self, **kwargs): - args = {"query_rel": self.create_query_with_params()} + args = {"query": self.create_query_with_params()} args.update(kwargs) return visualization_factory.create(**args) diff --git a/tests/handlers/test_alerts.py b/tests/handlers/test_alerts.py index 035ee12200..5c4884bbd7 100644 --- a/tests/handlers/test_alerts.py +++ b/tests/handlers/test_alerts.py @@ -12,7 +12,7 @@ def test_returns_200_if_allowed(self): def test_returns_403_if_not_allowed(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query_rel=query) + alert = self.factory.create_alert(query=query) db.session.commit() rv = self.make_request("get", "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 403) @@ -88,7 +88,7 @@ def test_returns_alerts_only_from_users_groups(self): query = self.factory.create_query( data_source=self.factory.create_data_source(group=self.factory.create_group()) ) - alert2 = self.factory.create_alert(query_rel=query) + alert2 = self.factory.create_alert(query=query) rv = self.make_request("get", "/api/alerts") self.assertEqual(rv.status_code, 200) @@ -151,7 +151,7 @@ def test_subscribers_user_to_alert(self): def test_doesnt_subscribers_user_to_alert_without_access(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query_rel=query) + alert = self.factory.create_alert(query=query) destination = self.factory.create_destination() rv = self.make_request( @@ -173,7 +173,7 @@ def test_returns_subscribers(self): def test_doesnt_return_subscribers_when_not_allowed(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query_rel=query) + alert = self.factory.create_alert(query=query) rv = self.make_request("get", "/api/alerts/{}/subscriptions".format(alert.id)) self.assertEqual(rv.status_code, 403) diff --git a/tests/handlers/test_dashboards.py b/tests/handlers/test_dashboards.py index b20614d731..2ba24fcae0 100644 --- a/tests/handlers/test_dashboards.py +++ b/tests/handlers/test_dashboards.py @@ -71,7 +71,7 @@ def test_get_dashboard_filters_unauthorized_widgets(self): restricted_ds = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=restricted_ds) - vis = self.factory.create_visualization(query_rel=query) + vis = self.factory.create_visualization(query=query) restricted_widget = self.factory.create_widget(visualization=vis, dashboard=dashboard) widget = self.factory.create_widget(dashboard=dashboard) dashboard.layout = [[widget.id, restricted_widget.id]] diff --git a/tests/handlers/test_destinations.py b/tests/handlers/test_destinations.py index 3a58fbda90..5a11fd2fbe 100644 --- a/tests/handlers/test_destinations.py +++ b/tests/handlers/test_destinations.py @@ -212,7 +212,7 @@ def test_slack_notify_calls_requests_post(): options = {"url": "https://slack.com/api/api.test"} metadata = {"Scheduled": False} - new_state = Alert.TRIGGERED_STATE + new_state = Alerts.TRIGGERED_STATE destination = Slack(options) with mock.patch("redash.destinations.slack.requests.post") as mock_post: diff --git a/tests/handlers/test_embed.py b/tests/handlers/test_embed.py index 59d3920d7d..55d15bd192 100644 --- a/tests/handlers/test_embed.py +++ b/tests/handlers/test_embed.py @@ -14,12 +14,12 @@ def test_not_embedable(self): class TestEmbedVisualization(BaseTestCase): def test_sucesss(self): vis = self.factory.create_visualization() - vis.query_rel.latest_query_data = self.factory.create_query_result() - db.session.add(vis.query_rel) + vis.query.latest_query_data = self.factory.create_query_result() + db.session.add(vis.query) res = self.make_request( "get", - "/embed/query/{}/visualization/{}".format(vis.query_rel.id, vis.id), + "/embed/query/{}/visualization/{}".format(vis.query.id, vis.id), is_json=False, ) self.assertEqual(res.status_code, 200) diff --git a/tests/handlers/test_visualizations.py b/tests/handlers/test_visualizations.py index 8410e50a28..af8e563195 100644 --- a/tests/handlers/test_visualizations.py +++ b/tests/handlers/test_visualizations.py @@ -1,3 +1,4 @@ +from sqlalchemy import func from sqlalchemy.sql.expression import select from redash import models @@ -28,7 +29,7 @@ def test_delete_visualization(self): rv = self.make_request("delete", "/api/visualizations/{}".format(visualization.id)) self.assertEqual(rv.status_code, 200) - self.assertEqual(models.Visualization.query.count(), 0) + self.assertEqual(models.db.session.scalar(select(func.count(models.Visualization.id))), 0) def test_update_visualization(self): visualization = self.factory.create_visualization() diff --git a/tests/handlers/test_widgets.py b/tests/handlers/test_widgets.py index 6756ecf606..a55a721e8f 100644 --- a/tests/handlers/test_widgets.py +++ b/tests/handlers/test_widgets.py @@ -26,9 +26,9 @@ def test_wont_create_widget_for_visualization_you_dont_have_access_to(self): dashboard = self.factory.create_dashboard() vis = self.factory.create_visualization() ds = self.factory.create_data_source(group=self.factory.create_group()) - vis.query_rel.data_source = ds + vis.query.data_source = ds - models.db.session.add(vis.query_rel) + models.db.session.add(vis.query) data = { "visualization_id": vis.id, @@ -63,4 +63,4 @@ def test_delete_widget(self): self.assertEqual(rv.status_code, 200) dashboard = models.Dashboard.get_by_slug_and_org(widget.dashboard.slug, widget.dashboard.org) - self.assertEqual(dashboard.widgets.count(), 0) + self.assertEqual(len(dashboard.widgets), 0) diff --git a/tests/models/test_alerts.py b/tests/models/test_alerts.py index 36ba372977..d45508868a 100644 --- a/tests/models/test_alerts.py +++ b/tests/models/test_alerts.py @@ -15,8 +15,8 @@ def test_returns_all_alerts_for_given_groups(self): query1 = self.factory.create_query(data_source=ds1) query2 = self.factory.create_query(data_source=ds2) - alert1 = self.factory.create_alert(query_rel=query1) - alert2 = self.factory.create_alert(query_rel=query2) + alert1 = self.factory.create_alert(query=query1) + alert2 = self.factory.create_alert(query=query2) db.session.flush() alerts = db.session.scalars(Alert.all(group_ids=[group.id, self.factory.default_group.id])).all() @@ -50,7 +50,7 @@ class TestAlertEvaluate(BaseTestCase): def create_alert(self, results, column="foo", value="1"): result = self.factory.create_query_result(data=results) query = self.factory.create_query(latest_query_data_id=result.id) - alert = self.factory.create_alert(query_rel=query, options={"op": "equals", "column": column, "value": value}) + alert = self.factory.create_alert(query=query, options={"op": "equals", "column": column, "value": value}) return alert def test_evaluate_triggers_alert_when_equal(self): @@ -95,7 +95,7 @@ class TestAlertRenderTemplate(BaseTestCase): def create_alert(self, results, column="foo", value="5"): result = self.factory.create_query_result(data=results) query = self.factory.create_query(latest_query_data_id=result.id) - alert = self.factory.create_alert(query_rel=query, options={"op": "equals", "column": column, "value": value}) + alert = self.factory.create_alert(query=query, options={"op": "equals", "column": column, "value": value}) return alert def test_render_custom_alert_template(self): diff --git a/tests/models/test_dashboards.py b/tests/models/test_dashboards.py index 75b1fa18c7..620b16ffde 100644 --- a/tests/models/test_dashboards.py +++ b/tests/models/test_dashboards.py @@ -1,3 +1,5 @@ +from sqlalchemy import distinct, func + from redash.models import Dashboard, db from tests import BaseTestCase @@ -9,9 +11,9 @@ def create_tagged_dashboard(self, tags): query = self.factory.create_query(data_source=ds) # We need a bunch of visualizations and widgets configured # to trigger wrong counts via the left outer joins. - vis1 = self.factory.create_visualization(query_rel=query) - vis2 = self.factory.create_visualization(query_rel=query) - vis3 = self.factory.create_visualization(query_rel=query) + vis1 = self.factory.create_visualization(query=query) + vis2 = self.factory.create_visualization(query=query) + vis3 = self.factory.create_visualization(query=query) widget1 = self.factory.create_widget(visualization=vis1, dashboard=dashboard) widget2 = self.factory.create_widget(visualization=vis2, dashboard=dashboard) widget3 = self.factory.create_widget(visualization=vis3, dashboard=dashboard) @@ -63,10 +65,10 @@ def test_returns_correct_number_of_dashboards(self): qry2 = self.factory.create_query(data_source=ds2, user=usr) viz1 = self.factory.create_visualization( - query_rel=qry1, + query=qry1, ) viz2 = self.factory.create_visualization( - query_rel=qry2, + query=qry2, ) def create_dashboard(): @@ -79,6 +81,10 @@ def create_dashboard(): create_dashboard() create_dashboard() - results = db.session.scalars(Dashboard.all(self.factory.org, usr.group_ids, usr.id)).all() + result = db.session.scalar( + Dashboard.all( + self.factory.org, usr.group_ids, usr.id, columns=[func.count(distinct(Dashboard.id))], distinct=[] + ) + ) - self.assertEqual(2, len(results), "The incorrect number of dashboards were returned") + self.assertEqual(2, result, "The incorrect number of dashboards were returned") diff --git a/tests/models/test_data_sources.py b/tests/models/test_data_sources.py index e0a072975b..de9f1c083e 100644 --- a/tests/models/test_data_sources.py +++ b/tests/models/test_data_sources.py @@ -1,5 +1,6 @@ import mock from mock import patch +from sqlalchemy import func from sqlalchemy.sql.expression import select from redash.models import DataSource, Query, QueryResult, db @@ -166,7 +167,7 @@ def test_deletes_child_models(self): data_source.delete() self.assertIsNone(db.session.get(DataSource, data_source.id)) self.assertEqual( - 0, len(db.session.scalars(select(QueryResult).where(QueryResult.data_source == data_source)).all()) + 0, db.session.scalar(select(func.count(QueryResult.id)).where(QueryResult.data_source == data_source)) ) @patch("redash.redis_connection.delete") diff --git a/tests/models/test_queries.py b/tests/models/test_queries.py index e79f44920d..00fc831a3f 100644 --- a/tests/models/test_queries.py +++ b/tests/models/test_queries.py @@ -366,8 +366,8 @@ def assert_visualizations(self, origin_q, origin_v, forked_q, forked_v): self.assertEqual(origin_v.options, forked_v.options) self.assertEqual(origin_v.type, forked_v.type) self.assertNotEqual(origin_v.id, forked_v.id) - self.assertNotEqual(origin_v.query_rel, forked_v.query_rel) - self.assertEqual(forked_q.id, forked_v.query_rel.id) + self.assertNotEqual(origin_v.query, forked_v.query) + self.assertEqual(forked_q.id, forked_v.query.id) def test_fork_with_visualizations(self): # prepare original query and visualizations @@ -375,10 +375,10 @@ def test_fork_with_visualizations(self): query = self.factory.create_query(data_source=data_source, description="this is description") # create default TABLE - query factory does not create it - self.factory.create_visualization(query_rel=query, name="Table", description="", type="TABLE", options={}) + self.factory.create_visualization(query=query, name="Table", description="", type="TABLE", options={}) visualization_chart = self.factory.create_visualization( - query_rel=query, + query=query, description="chart vis", type="CHART", options={ @@ -394,7 +394,7 @@ def test_fork_with_visualizations(self): }, ) visualization_box = self.factory.create_visualization( - query_rel=query, description="box vis", type="BOXPLOT", options={} + query=query, description="box vis", type="BOXPLOT", options={} ) fork_user = self.factory.create_user() forked_query = query.fork(fork_user) @@ -437,7 +437,7 @@ def test_fork_from_query_that_has_no_visualization(self): query = self.factory.create_query(data_source=data_source, description="this is description") # create default TABLE - query factory does not create it - self.factory.create_visualization(query_rel=query, name="Table", description="", type="TABLE", options={}) + self.factory.create_visualization(query=query, name="Table", description="", type="TABLE", options={}) fork_user = self.factory.create_user() @@ -488,8 +488,6 @@ def test_updates_existing_queries(self): db.session.commit() - Query.update_latest_result(query_result) - self.assertEqual(query1.latest_query_data, query_result) self.assertEqual(query2.latest_query_data, query_result) self.assertEqual(query3.latest_query_data, None) @@ -511,8 +509,6 @@ def test_doesnt_update_queries_with_different_hash(self): db.session.commit() - Query.update_latest_result(query_result) - self.assertEqual(query1.latest_query_data, query_result) self.assertEqual(query2.latest_query_data, query_result) self.assertNotEqual(query3.latest_query_data, query_result) @@ -534,8 +530,6 @@ def test_doesnt_update_queries_with_different_data_source(self): db.session.commit() - Query.update_latest_result(query_result) - self.assertEqual(query1.latest_query_data, query_result) self.assertEqual(query2.latest_query_data, query_result) self.assertNotEqual(query3.latest_query_data, query_result) diff --git a/tests/tasks/test_alerts.py b/tests/tasks/test_alerts.py index 1a736929e2..9a7abc074a 100644 --- a/tests/tasks/test_alerts.py +++ b/tests/tasks/test_alerts.py @@ -43,7 +43,7 @@ def test_calls_notify_for_subscribers(self): notify_subscriptions(subscription.alert, Alerts.OK_STATE, metadata={"Scheduled": False}) subscription.notify.assert_called_with( subscription.alert, - subscription.alert.query_rel, + subscription.alert.query, subscription.user, Alerts.OK_STATE, ANY, diff --git a/tests/test_authentication.py b/tests/test_authentication.py index a8b853864d..0b922648b0 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -12,15 +12,8 @@ from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql.expression import select +import redash.authentication as auth from redash import models, settings -from redash.authentication import ( - api_key_load_user_from_request, - get_login_url, - hmac_load_user_from_request, - jwt_auth, - org_settings, - sign, -) from redash.authentication.google_oauth import ( create_and_login_user, verify_profile, @@ -43,29 +36,29 @@ def setUp(self): def test_no_api_key(self): with self.app.test_client() as c: c.get(self.query_url) - self.assertIsNone(api_key_load_user_from_request(request)) + self.assertIsNone(auth.api_key_load_user_from_request(request)) def test_wrong_api_key(self): with self.app.test_client() as c: c.get(self.query_url, query_string={"api_key": "whatever"}) - self.assertIsNone(api_key_load_user_from_request(request)) + self.assertIsNone(auth.api_key_load_user_from_request(request)) def test_correct_api_key(self): with self.app.test_client() as c: c.get(self.query_url, query_string={"api_key": self.api_key}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + self.assertIsNotNone(auth.api_key_load_user_from_request(request)) def test_no_query_id(self): with self.app.test_client() as c: c.get(self.queries_url, query_string={"api_key": self.api_key}) - self.assertIsNone(api_key_load_user_from_request(request)) + self.assertIsNone(auth.api_key_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") models.db.session.flush() with self.app.test_client() as c: c.get(self.queries_url, query_string={"api_key": user.api_key}) - self.assertEqual(user.id, api_key_load_user_from_request(request).id) + self.assertEqual(user.id, auth.api_key_load_user_from_request(request).id) def test_disabled_user_api_key(self): user = self.factory.create_user(api_key="user_key") @@ -73,17 +66,17 @@ def test_disabled_user_api_key(self): models.db.session.flush() with self.app.test_client() as c: c.get(self.queries_url, query_string={"api_key": user.api_key}) - self.assertEqual(None, api_key_load_user_from_request(request)) + self.assertEqual(None, auth.api_key_load_user_from_request(request)) def test_api_key_header(self): with self.app.test_client() as c: c.get(self.query_url, headers={"Authorization": "Key {}".format(self.api_key)}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + self.assertIsNotNone(auth.api_key_load_user_from_request(request)) def test_api_key_header_with_wrong_key(self): with self.app.test_client() as c: c.get(self.query_url, headers={"Authorization": "Key oops"}) - self.assertIsNone(api_key_load_user_from_request(request)) + self.assertIsNone(auth.api_key_load_user_from_request(request)) def test_api_key_for_wrong_org(self): other_user = self.factory.create_admin(org=self.factory.create_org()) @@ -109,12 +102,12 @@ def setUp(self): self.expires = time.time() + 1800 def signature(self, expires): - return sign(self.query.api_key, self.path, expires) + return auth.sign(self.query.api_key, self.path, expires) def test_no_signature(self): with self.app.test_client() as c: c.get(self.path) - self.assertIsNone(hmac_load_user_from_request(request)) + self.assertIsNone(auth.hmac_load_user_from_request(request)) def test_wrong_signature(self): with self.app.test_client() as c: @@ -122,7 +115,7 @@ def test_wrong_signature(self): self.path, query_string={"signature": "whatever", "expires": self.expires}, ) - self.assertIsNone(hmac_load_user_from_request(request)) + self.assertIsNone(auth.hmac_load_user_from_request(request)) def test_correct_signature(self): with self.app.test_client() as c: @@ -133,7 +126,7 @@ def test_correct_signature(self): "expires": self.expires, }, ) - self.assertIsNotNone(hmac_load_user_from_request(request)) + self.assertIsNotNone(auth.hmac_load_user_from_request(request)) def test_no_query_id(self): with self.app.test_client() as c: @@ -141,14 +134,14 @@ def test_no_query_id(self): "/{}/api/queries".format(self.query.org.slug), query_string={"api_key": self.api_key}, ) - self.assertIsNone(hmac_load_user_from_request(request)) + self.assertIsNone(auth.hmac_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") path = "/api/queries/" models.db.session.flush() - signature = sign(user.api_key, path, self.expires) + signature = auth.sign(user.api_key, path, self.expires) with self.app.test_client() as c: c.get( path, @@ -158,7 +151,7 @@ def test_user_api_key(self): "user_id": user.id, }, ) - self.assertEqual(user.id, hmac_load_user_from_request(request).id) + self.assertEqual(user.id, auth.hmac_load_user_from_request(request).id) class TestSessionAuthentication(BaseTestCase): @@ -242,11 +235,11 @@ def test_user_not_in_domain_but_account_exists(self): class TestGetLoginUrl(BaseTestCase): def test_when_multi_org_enabled_and_org_exists(self): with self.app.test_request_context("/{}/".format(self.factory.org.slug)): - self.assertEqual(get_login_url(next=None), "/{}/login".format(self.factory.org.slug)) + self.assertEqual(auth.get_login_url(next=None), "/{}/login".format(self.factory.org.slug)) def test_when_multi_org_enabled_and_org_doesnt_exist(self): with self.app.test_request_context("/{}_notexists/".format(self.factory.org.slug)): - self.assertEqual(get_login_url(next=None), "/") + self.assertEqual(auth.get_login_url(next=None), "/") class TestRedirectToUrlAfterLoggingIn(BaseTestCase): @@ -415,67 +408,58 @@ def test_disabled_user_should_not_receive_password_reset_link(self): send_user_disabled_email_mock.assert_called_with(user) +rsa_private_key = "/tmp/jwtRS256.key" +rsa_public_key = "/tmp/jwtRS256.pem" +org_settings = dict(auth.org_settings) +org_settings.update( + { + "auth_jwt_login_enabled": True, + "auth_jwt_auth_public_certs_url": "file:///{}".format(rsa_public_key), + "auth_jwt_auth_issuer": "Admin", + "auth_jwt_auth_audience": "My Org", + "auth_jwt_auth_header_name": "jwt-token", + } +) + + +@patch("redash.authentication.org_settings", org_settings) class TestJWTAuthentication(BaseTestCase): def setUp(self): super(TestJWTAuthentication, self).setUp() - self.auth_audience = "My Org" - self.auth_issuer = "Admin" - self.token_name = "jwt-token" - self.rsa_private_key = "/tmp/jwtRS256.key" - self.rsa_public_key = "/tmp/jwtRS256.pem" - - if not os.path.exists(self.rsa_public_key): - subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"]) - subprocess.check_output( - ["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key] - ) - - org_settings["auth_jwt_login_enabled"] = True - org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key) - org_settings["auth_jwt_auth_issuer"] = self.auth_issuer - org_settings["auth_jwt_auth_audience"] = self.auth_audience - org_settings["auth_jwt_auth_header_name"] = self.token_name - - def tearDown(self): - org_settings["auth_jwt_login_enabled"] = False - org_settings["auth_jwt_auth_public_certs_url"] = "" - org_settings["auth_jwt_auth_issuer"] = "" - org_settings["auth_jwt_auth_audience"] = "" - org_settings["auth_jwt_auth_header_name"] = "" + if not os.path.exists(rsa_public_key): + subprocess.check_output(["openssl", "genrsa", "-out", rsa_private_key, "4096"]) + subprocess.check_output(["openssl", "rsa", "-pubout", "-in", rsa_private_key, "-out", rsa_public_key]) - def jwt_no_token(self): + def test_jwt_no_token(self): response = self.get_request("/data_sources", org=self.factory.org) self.assertEqual(response.status_code, 302) - def jwt_from_pem_file(self): + def test_jwt_from_pem_file(self): user = self.factory.create_user() - issued_at_timestamp = time.time() expiration_timestamp = issued_at_timestamp + 60 - data = { - "aud": self.auth_audience, + "aud": org_settings["auth_jwt_auth_audience"], "email": user.email, "exp": expiration_timestamp, "iat": issued_at_timestamp, - "iss": self.auth_issuer, + "iss": org_settings["auth_jwt_auth_issuer"], } - with open(self.rsa_private_key) as keyfile: + with open(rsa_private_key) as keyfile: sign_key = keyfile.read().strip() token_data = jwt.encode(data, sign_key, algorithm="RS256") - - response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data}) + response = self.get_request( + "/data_sources", org=self.factory.org, headers={org_settings["auth_jwt_auth_header_name"]: token_data} + ) self.assertEqual(response.status_code, 200) @patch.object(requests, "get") - def jwk_decode(self, mock_get): - with open(self.rsa_public_key, "rb") as keyfile: + def test_jwk_decode(self, mock_get): + with open(rsa_public_key, "rb") as keyfile: public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read()) jwk_keys = {"keys": [json.loads(public_key.export())]} - mockresponse = Mock() mockresponse.json = lambda: jwk_keys mock_get.return_value = mockresponse - - keys = jwt_auth.get_public_keys("http://localhost/key.jwt") + keys = auth.jwt_auth.get_public_keys("http://localhost/key.jwt") self.assertEqual(keys[0].key_size, 4096) diff --git a/tests/test_cli.py b/tests/test_cli.py index 77f01e2803..830685b563 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import mock from click.testing import CliRunner +from sqlalchemy import func from sqlalchemy.sql.expression import select from redash.cli import manager @@ -22,7 +23,7 @@ def test_interactive_new(self): ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 1) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 1) ds = db.session.scalar(select(DataSource)) self.assertEqual(ds.name, "test") self.assertEqual(ds.type, "pg") @@ -44,7 +45,7 @@ def test_options_new(self): ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 1) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 1) ds = db.session.scalar(select(DataSource)) self.assertEqual(ds.name, "test") self.assertEqual(ds.type, "pg") @@ -57,7 +58,7 @@ def test_bad_type_new(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("not supported", result.output) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 0) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 0) def test_bad_options_new(self): runner = CliRunner() @@ -76,7 +77,7 @@ def test_bad_options_new(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("invalid configuration", result.output) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 0) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 0) def test_list(self): self.factory.create_data_source( @@ -152,7 +153,7 @@ def test_connection_delete(self): self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertIn("Deleting", result.output) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 0) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 0) def test_connection_bad_delete(self): self.factory.create_data_source( @@ -165,7 +166,7 @@ def test_connection_bad_delete(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("Couldn't find", result.output) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 1) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 1) def test_options_edit(self): self.factory.create_data_source( @@ -190,7 +191,7 @@ def test_options_edit(self): ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(len(db.session.scalars(select(DataSource)).all()), 1) + self.assertEqual(db.session.scalar(select(func.count(DataSource.id))), 1) ds = db.session.scalar(select(DataSource)) self.assertEqual(ds.name, "test2") self.assertEqual(ds.type, "pg") @@ -240,13 +241,13 @@ def test_bad_options_edit(self): class GroupCommandTests(BaseTestCase): def test_create(self): - gcount = len(db.session.scalars(select(Group)).all()) + gcount = db.session.scalar(select(func.count(Group.id))) perms = ["create_query", "edit_query", "view_query"] runner = CliRunner() result = runner.invoke(manager, ["groups", "create", "test", "--permissions", ",".join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(len(db.session.scalars(select(Group)).all()), gcount + 1) + self.assertEqual(db.session.scalar(select(func.count(Group.id))), gcount + 1) g = db.session.scalar(select(Group).order_by(Group.id.desc())) db.session.add(self.factory.org) self.assertEqual(g.org_id, self.factory.org.id) @@ -360,7 +361,7 @@ def test_create(self): self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - ucount = len(db.session.scalars(select(Organization)).all()) + ucount = db.session.scalar(select(func.count(Organization.id))) self.assertEqual(ucount, 2) @@ -455,20 +456,20 @@ def test_create_bad(self): def test_delete(self): self.factory.create_user(email="foobar@example.com") - ucount = len(db.session.scalars(select(User)).all()) + ucount = db.session.scalar(select(func.count(User.id))) runner = CliRunner() result = runner.invoke(manager, ["users", "delete", "foobar@example.com"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(len(db.session.scalars(select(User).where(User.email == "foobar@example.com")).all()), 0) - self.assertEqual(len(db.session.scalars(select(User)).all()), ucount - 1) + self.assertEqual(db.session.scalar(select(func.count(User.id)).where(User.email == "foobar@example.com")), 0) + self.assertEqual(db.session.scalar(select(func.count(User.id))), ucount - 1) def test_delete_bad(self): - ucount = len(db.session.scalars(select(User)).all()) + ucount = db.session.scalar(select(func.count(User.id))) runner = CliRunner() result = runner.invoke(manager, ["users", "delete", "foobar@example.com"]) self.assertIn("Deleted 0 users", result.output) - self.assertEqual(len(db.session.scalars(select(User)).all()), ucount) + self.assertEqual(db.session.scalar(select(func.count(User.id))), ucount) def test_password(self): self.factory.create_user(email="foobar@example.com") diff --git a/tests/test_handlers.py b/tests/test_handlers.py index da221bc608..16dfba14f2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,6 +1,7 @@ from flask_login import current_user from funcy import project from mock import patch +from sqlalchemy import func from sqlalchemy.sql.expression import select from redash import models, settings @@ -320,4 +321,4 @@ def test_delete(self): models.db.session.add(qs) models.db.session.commit() self.make_request("delete", "/api/query_snippets/1", user=self.factory.user) - self.assertEqual(len(models.db.session.scalars(select(models.QuerySnippet)).all()), 0) + self.assertEqual(models.db.session.scalar(select(func.count(models.QuerySnippet.id))), 0) diff --git a/tests/test_models.py b/tests/test_models.py index 4108fd536f..8ed844a0d6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -193,7 +193,7 @@ def test_outdated_queries_works_with_specific_time_schedule(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.create_scheduled_query(interval="86400", time=half_an_hour_ago.strftime("%H:%M")) query_result = self.factory.create_query_result( - query=query.query_text, + query=query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1), ) query.latest_query_data = query_result @@ -346,7 +346,7 @@ def test_archived_query_doesnt_return_in_all(self): def test_removes_associated_widgets_from_dashboards(self): widget = self.factory.create_widget() - query = widget.visualization.query_rel + query = widget.visualization.query db.session.commit() query.archive() db.session.flush() @@ -361,7 +361,7 @@ def test_removes_scheduling(self): def test_deletes_alerts(self): subscription = self.factory.create_alert_subscription() - query = subscription.alert.query_rel + query = subscription.alert.query db.session.commit() query.archive() db.session.flush() @@ -580,8 +580,8 @@ def _set_up_dashboard_test(d): ) d.q1 = d.factory.create_query(data_source=d.ds1) d.q2 = d.factory.create_query(data_source=d.ds2) - d.v1 = d.factory.create_visualization(query_rel=d.q1) - d.v2 = d.factory.create_visualization(query_rel=d.q2) + d.v1 = d.factory.create_visualization(query=d.q1) + d.v2 = d.factory.create_visualization(query=d.q2) d.w1 = d.factory.create_widget(visualization=d.v1) d.w2 = d.factory.create_widget(visualization=d.v2) d.w3 = d.factory.create_widget(visualization=d.v2, dashboard=d.w2.dashboard)