From 24217d969ecddf0d12a1e6b3d42a88bda0b8898d Mon Sep 17 00:00:00 2001 From: Allen Short Date: Mon, 21 Nov 2016 09:34:44 -0600 Subject: [PATCH 01/80] schema for sqlalchemy, basic test support --- redash/__init__.py | 5 +- redash/admin.py | 35 +- redash/authentication/google_oauth.py | 2 +- redash/cli/database.py | 6 +- redash/cli/users.py | 2 +- redash/handlers/alerts.py | 1 + redash/handlers/base.py | 10 +- redash/handlers/users.py | 4 +- redash/metrics/database.py | 3 - redash/models.py | 786 ++++++++++------------ redash/settings.py | 3 +- requirements.txt | 4 +- tests/__init__.py | 9 +- tests/factories.py | 55 +- tests/models/test_base_versioned_model.py | 60 -- tests/test_models.py | 17 +- 16 files changed, 433 insertions(+), 569 deletions(-) delete mode 100644 tests/models/test_base_versioned_model.py diff --git a/redash/__init__.py b/redash/__init__.py index 3b0a24cae6..593fc0f2bb 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -105,13 +105,12 @@ def create_app(): logging.getLogger().addHandler(sentry_handler) # configure our database - settings.DATABASE_CONFIG.update({'threadlocals': True}) - app.config['DATABASE'] = settings.DATABASE_CONFIG + app.config['SQLALCHEMY_DATABASE_URI'] = settings.SQLALCHEMY_DATABASE_URI app.config.update(settings.all_settings()) provision_app(app) - init_admin(app) db.init_app(app) + init_admin(app) mail.init_app(app) setup_authentication(app) handlers.init_app(app) diff --git a/redash/admin.py b/redash/admin.py index 0084ebc0f7..731f0ffe35 100644 --- a/redash/admin.py +++ b/redash/admin.py @@ -1,10 +1,9 @@ import json from flask_admin import Admin from flask_admin.base import MenuLink -from flask_admin.contrib.peewee import ModelView -from flask_admin.contrib.peewee.form import CustomModelConverter +from flask_admin.contrib.sqla import ModelView +from flask_admin.contrib.sqla.form import AdminModelConverter from flask_admin.form.widgets import DateTimePickerWidget -from playhouse.postgres_ext import ArrayField, DateTimeTZField from wtforms import fields from wtforms.widgets import TextInput @@ -40,29 +39,9 @@ def process_formdata(self, valuelist): self.data = '' -class PgModelConverter(CustomModelConverter): - def __init__(self, view, additional=None): - additional = {ArrayField: self.handle_array_field, - DateTimeTZField: self.handle_datetime_tz_field, - models.JSONField: self.handle_json_field, - } - super(PgModelConverter, self).__init__(view, additional) - self.view = view - - def handle_json_field(self, model, field, **kwargs): - return field.name, JSONTextAreaField(**kwargs) - - def handle_array_field(self, model, field, **kwargs): - return field.name, ArrayListField(**kwargs) - - def handle_datetime_tz_field(self, model, field, **kwargs): - kwargs['widget'] = DateTimePickerWidget() - return field.name, fields.DateTimeField(**kwargs) - - class BaseModelView(ModelView): column_display_pk = True - model_form_converter = PgModelConverter + model_form_converter = AdminModelConverter @require_super_admin def is_accessible(self): @@ -84,12 +63,12 @@ class DashboardModelView(BaseModelView): def init_admin(app): admin = Admin(app, name='re:dash admin', template_mode='bootstrap3') - admin.add_view(QueryModelView(models.Query)) - admin.add_view(QueryResultModelView(models.QueryResult)) - admin.add_view(DashboardModelView(models.Dashboard)) + admin.add_view(QueryModelView(models.Query, models.db.session)) + admin.add_view(QueryResultModelView(models.QueryResult, models.db.session)) + admin.add_view(DashboardModelView(models.Dashboard, models.db.session)) logout_link = MenuLink('Logout', '/logout', 'logout') for m in (models.Visualization, models.Widget, models.Event, models.Organization): - admin.add_view(BaseModelView(m)) + admin.add_view(BaseModelView(m, models.db.session)) admin.add_link(logout_link) diff --git a/redash/authentication/google_oauth.py b/redash/authentication/google_oauth.py index 0dba2ea05c..bfda933313 100644 --- a/redash/authentication/google_oauth.py +++ b/redash/authentication/google_oauth.py @@ -65,7 +65,7 @@ def create_and_login_user(org, name, email): user_object.save() except models.User.DoesNotExist: logger.debug("Creating user object (%r)", name) - user_object = models.User.create(org=org, name=name, email=email, groups=[org.default_group.id]) + user_object = models.User.create(org=org, name=name, email=email, group_ids=[org.default_group.id]) login_user(user_object, remember=True) diff --git a/redash/cli/database.py b/redash/cli/database.py index 3df6291fed..0de7e257e0 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -6,10 +6,10 @@ @manager.command() def create_tables(): """Create the database tables.""" - from redash.models import create_db, init_db - - create_db(True, False) + from redash.models import db, create_db, init_db + create_db(True, True) init_db() + db.session.commit() @manager.command() diff --git a/redash/cli/users.py b/redash/cli/users.py index 1b624d4cd8..4f2ffdadfe 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -1,7 +1,7 @@ from sys import exit from click import BOOL, Group, argument, option, prompt -from peewee import IntegrityError +from sqlalchemy.exc import IntegrityError from redash import models from redash.handlers.users import invite_user diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index 714dc8f730..4188889411 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -6,6 +6,7 @@ from redash import models from redash.permissions import require_access, require_admin_or_owner, view_only, require_permission from redash.handlers.base import BaseResource, require_fields, get_object_or_404 +from sqlalchemy.exc import DataError class AlertResource(BaseResource): diff --git a/redash/handlers/base.py b/redash/handlers/base.py index f5c5506ccd..d213a12d68 100644 --- a/redash/handlers/base.py +++ b/redash/handlers/base.py @@ -3,7 +3,8 @@ from flask import Blueprint, current_app, request from flask_login import current_user, login_required from flask_restful import Resource, abort -from peewee import DoesNotExist +from sqlalchemy.exc import DataError + from redash import settings from redash.authentication import current_org from redash.models import ApiUser @@ -67,10 +68,11 @@ def require_fields(req, fields): def get_object_or_404(fn, *args, **kwargs): - try: - return fn(*args, **kwargs) - except DoesNotExist: + rv = fn(*args, **kwargs) + if rv is None: abort(404) + else: + return rv def paginate(query_set, page, page_size, serializer): diff --git a/redash/handlers/users.py b/redash/handlers/users.py index 4bb977b8eb..aa0bac81d3 100644 --- a/redash/handlers/users.py +++ b/redash/handlers/users.py @@ -2,7 +2,7 @@ from flask import request from flask_restful import abort from funcy import project -from peewee import IntegrityError +from sqlalchemy.exc import IntegrityError from redash import models from redash.permissions import require_permission, require_admin_or_owner, is_admin_or_owner, \ @@ -31,7 +31,7 @@ def post(self): user = models.User(org=self.current_org, name=req['name'], email=req['email'], - groups=[self.current_org.default_group.id]) + group_ids=[self.current_org.default_group.id]) try: user.save() diff --git a/redash/metrics/database.py b/redash/metrics/database.py index 2e8a8100e0..a4ce6e312f 100644 --- a/redash/metrics/database.py +++ b/redash/metrics/database.py @@ -2,7 +2,6 @@ import time import logging from playhouse.gfk import Model -import peewee from playhouse.postgres_ext import PostgresqlExtDatabase from redash import statsd_client @@ -63,8 +62,6 @@ def extended_clone(self): peewee.Query._execute = metered_execute peewee.Query.clone = extended_clone -patch_query_execute() - class MeteredModel(Model): @classmethod diff --git a/redash/models.py b/redash/models.py index 18ace0ebc4..114891f393 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1,238 +1,133 @@ -import json -from flask_login import UserMixin, AnonymousUserMixin +import datetime +import functools import hashlib +import itertools +import json import logging import os import threading import time -import datetime -import itertools + from funcy import project -import peewee +from flask_sqlalchemy import SQLAlchemy +from flask.ext.sqlalchemy import SignallingSession +from flask_login import UserMixin, AnonymousUserMixin +from sqlalchemy.dialects import postgresql +from sqlalchemy.event import listens_for +from sqlalchemy.types import TypeDecorator + from passlib.apps import custom_app_context as pwd_context from playhouse.gfk import GFKField, BaseModel from playhouse.postgres_ext import ArrayField, DateTimeTZField -from permissions import has_access, view_only -from redash import utils, settings, redis_connection -from redash.query_runner import get_query_runner, get_configuration_schema_for_query_runner_type + + +from redash import redis_connection, settings, utils from redash.destinations import get_destination, get_configuration_schema_for_destination_type from redash.metrics.database import MeteredPostgresqlExtDatabase, MeteredModel +from redash.permissions import has_access, view_only +from redash.query_runner import get_query_runner, get_configuration_schema_for_query_runner_type from redash.utils import generate_token, json_dumps from redash.utils.configuration import ConfigurationContainer +db = SQLAlchemy() +Column = functools.partial(db.Column, nullable=False) -class Database(object): - def __init__(self): - self.database_config = dict(settings.DATABASE_CONFIG) - self.database_config['register_hstore'] = False - self.database_name = self.database_config.pop('name') - self.database = MeteredPostgresqlExtDatabase(self.database_name, **self.database_config) - self.app = None - self.pid = os.getpid() - - def init_app(self, app): - self.app = app - self.register_handlers() - - def connect_db(self): - self._check_pid() - self.database.reset_metrics() - self.database.connect() - - def close_db(self, exc): - self._check_pid() - if not self.database.is_closed(): - self.database.close() +# AccessPermission and Change use a 'generic foreign key' approach to refer to +# either queries or dashboards. +# TODO replace this with association tables. +_gfk_types = {} - def _check_pid(self): - current_pid = os.getpid() - if self.pid != current_pid: - logging.info("New pid detected (%d!=%d); resetting database lock.", self.pid, current_pid) - self.pid = os.getpid() - self.database._conn_lock = threading.Lock() +class GFKBase(object): + """ + Compatibility with 'generic foreign key' approach Peewee used. + """ + # XXX Replace this with table-per-association. + object_type = Column(db.String(255)) + object_id = Column(db.Integer) - def register_handlers(self): - self.app.before_request(self.connect_db) - self.app.teardown_request(self.close_db) + _object = None + @property + def object(self): + session = object_session(self) + if self._object or not session: + return self._object + else: + object_class = _gfk_types[self.object_type] + self._object = session.query(object_class).filter( + object_class.id == self.object_id).first() + return self._object -db = Database() + @object.setter + def object(self, value): + self._object = value + self.object_type = value.__class__.__tablename__ + self.object_id = value.id -# Support for cast operation on database fields -@peewee.Node.extend() -def cast(self, as_type): - return peewee.Expression(self, '::', peewee.SQL(as_type)) +# # Support for cast operation on database fields +# @peewee.Node.extend() +# def cast(self, as_type): +# return peewee.Expression(self, '::', peewee.SQL(as_type)) -class JSONField(peewee.TextField): - def db_value(self, value): +class PseudoJSON(TypeDecorator): + impl = db.Text + def process_bind_param(self, value, dialect): return json_dumps(value) - - def python_value(self, value): + def process_result_value(self, value, dialect): if not value: return value return json.loads(value) -class BaseModel(MeteredModel): - class Meta: - database = db.database - - @classmethod - def get_by_id(cls, model_id): - return cls.get(cls.id == model_id) - - def pre_save(self, created): - pass - - def post_save(self, created): - # Handler for post_save operations. Overriding if needed. - pass - - def save(self, *args, **kwargs): - pk_value = self._get_pk_value() - created = kwargs.get('force_insert', False) or not bool(pk_value) - self.pre_save(created) - super(BaseModel, self).save(*args, **kwargs) - self.post_save(created) - - def update_instance(self, **kwargs): - for k, v in kwargs.items(): - # setattr(model_instance, field_name, field_obj.python_value(value)) - setattr(self, k, v) - - # We have to run pre-save before calculating dirty_fields. We end up running it twice, - # but pre_save calls should be very quick so it's not big of an issue. - # An alternative can be to recalculate dirty_fields, but it felt more error prone. - self.pre_save(False) - - self.save(only=self.dirty_fields) - - -class ModelTimestampsMixin(BaseModel): - updated_at = DateTimeTZField(default=datetime.datetime.now) - created_at = DateTimeTZField(default=datetime.datetime.now) - - def pre_save(self, created): - super(ModelTimestampsMixin, self).pre_save(created) - self.updated_at = datetime.datetime.now() - - -def _simple_value(v): - if isinstance(v, BaseModel): - return v.id - - return v +class TimestampMixin(object): + updated_at = Column(db.DateTime(True), default=db.func.now(), + onupdate=db.func.now(), nullable=False) + created_at = Column(db.DateTime(True), default=db.func.now(), + nullable=False) class ChangeTrackingMixin(object): skipped_fields = ('id', 'created_at', 'updated_at', 'version') + _clean_values = None - def prepared(self): - super(ChangeTrackingMixin, self).prepared() - - setattr(self, '_clean_values', {}) + def prep_cleanvalues(self): + self.__dict__['_clean_values'] = {} + for c in self.__class__.__table__.c: + self._clean_values[c.name] = None def __setattr__(self, key, value): - if hasattr(self, '_clean_values') and key in self._field_names(): + if self._clean_values is None: + self.prep_cleanvalues() + if key in self._clean_values: previous = getattr(self, key) self._clean_values[key] = previous super(ChangeTrackingMixin, self).__setattr__(key, value) - @property - def changes(self): + def record_changes(self, session, changed_by): changes = {} - - if not hasattr(self, '_clean_values'): - setattr(self, '_clean_values', {}) - for field in self._meta.get_fields(): - self._clean_values[field] = None - for k, v in self._clean_values.iteritems(): if k not in self.skipped_fields: - changes[k] = {'previous': _simple_value(v), 'current': _simple_value(getattr(self, k))} - - return changes - - def save(self, *args, **kwargs): - changed_by = kwargs.pop('changed_by', None) - pk_value = self._get_pk_value() - created = kwargs.get('force_insert', False) or not bool(pk_value) - - if created and changed_by is None: - changed_by = self.user - - ret = super(ChangeTrackingMixin, self).save(*args, **kwargs) - - if changed_by: - Change.log_change(changed_by, self) - self._clean_values = {} - - return ret - - def update_instance(self, **kwargs): - changed_by = kwargs.pop('changed_by', None) - ret = super(ChangeTrackingMixin, self).update_instance(**kwargs) - if changed_by: - Change.log_change(changed_by, self) - return ret - - def _field_names(self): - return [f.name for f in self._meta.get_fields()] - + changes[k] = {'previous': v, 'current': getattr(self, k)} + session.add(Change(object_type=self.__class__.__tablename__, + object_id=self.id, + object_version=self.version, + user_id=changed_by.id, + change=changes)) + session.add(self) class ConflictDetectedError(Exception): pass - -class BaseVersionedModel(BaseModel): - version = peewee.IntegerField(default=1) - - def save(self, *args, **kwargs): - pk_value = self._get_pk_value() - created = kwargs.get('force_insert', False) or not bool(pk_value) - - if created: - # Since this is an `INSERT`, just call regular save method. - return super(BaseVersionedModel, self).save() - - # Update any data that has changed and bump the version counter. - self.pre_save(False) - - field_data = dict(self._data) - current_version = field_data.pop('version', 0) - field_data = self._prune_fields(field_data, self.dirty_fields) - - # if not field_data: - # raise ValueError('No changes have been made.') - - ModelClass = type(self) - field_data['version'] = ModelClass.version + 1 # Atomic increment - - query = ModelClass.update(**field_data).where( - (ModelClass.version == current_version) & - (ModelClass.id == self.id)) - - nrows = query.execute() - if nrows == 0: - # It looks like another process has updated the version number. - raise ConflictDetectedError() # Raise exception? Return False? - else: - self.version += 1 # Update in-memory version number. - self._dirty.clear() - self.post_save(False) - return nrows - - class BelongsToOrgMixin(object): @classmethod def get_by_id_and_org(cls, object_id, org): - return cls.get(cls.id == object_id, cls.org == org) + return cls.query.filter(cls.id == object_id, cls.org == org).first() class PermissionsCheckMixin(object): @@ -277,33 +172,27 @@ def has_access(self, obj, access_type): return False -class Organization(ModelTimestampsMixin, BaseModel): +class Organization(TimestampMixin, db.Model): SETTING_GOOGLE_APPS_DOMAINS = 'google_apps_domains' SETTING_IS_PUBLIC = "is_public" - id = peewee.PrimaryKeyField() - name = peewee.CharField() - slug = peewee.CharField(unique=True) - settings = JSONField() + id = Column(db.Integer, primary_key=True) + name = Column(db.String(255)) + slug = Column(db.String(255), unique=True) + settings = Column(PseudoJSON) - class Meta: - db_table = 'organizations' + __tablename__ = 'organizations' def __repr__(self): return u"".format(self.id, self.name) - # When Organization is used with LocalProxy (like the current_org helper), peewee doesn't recognize it as a Model - # and might call int() on it. This method makes sure it works. - def __int__(self): - return self.id - @classmethod def get_by_slug(cls, slug): - return cls.get(cls.slug == slug) + return cls.query.filter(cls.slug == slug).first() @property def default_group(self): - return self.groups.where(Group.name=='default', Group.type==Group.BUILTIN_GROUP).first() + return self.groups.filter(Group.name == 'default', Group.type == Group.BUILTIN_GROUP).first() @property def google_apps_domains(self): @@ -315,13 +204,13 @@ def is_public(self): @property def admin_group(self): - return self.groups.where(Group.name=='admin', Group.type==Group.BUILTIN_GROUP).first() + return self.groups.filter(Group.name == 'admin', Group.type == Group.BUILTIN_GROUP).first() def has_user(self, email): - return self.users.where(User.email==email).count() == 1 + return self.users.filter(User.email == email).count() == 1 -class Group(BaseModel, BelongsToOrgMixin): +class Group(db.Model, BelongsToOrgMixin): DEFAULT_PERMISSIONS = ['create_dashboard', 'create_query', 'edit_dashboard', 'edit_query', 'view_query', 'view_source', 'execute_query', 'list_users', 'schedule_query', 'list_dashboards', 'list_alerts', 'list_data_sources'] @@ -329,15 +218,16 @@ class Group(BaseModel, BelongsToOrgMixin): BUILTIN_GROUP = 'builtin' REGULAR_GROUP = 'regular' - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="groups") - type = peewee.CharField(default=REGULAR_GROUP) - name = peewee.CharField(max_length=100) - permissions = ArrayField(peewee.CharField, default=DEFAULT_PERMISSIONS) - created_at = DateTimeTZField(default=datetime.datetime.now) + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship(Organization, backref="groups") + type = Column(db.String(255), default=REGULAR_GROUP) + name = Column(db.String(100)) + permissions = Column(postgresql.ARRAY(db.String(255)), + default=DEFAULT_PERMISSIONS) + created_at = Column(db.DateTime(True), default=db.func.now()) - class Meta: - db_table = 'groups' + __tablename__ = 'groups' def to_dict(self): return { @@ -350,36 +240,40 @@ def to_dict(self): @classmethod def all(cls, org): - return cls.select().where(cls.org==org) + return cls.query.filter(cls.org == org) @classmethod def members(cls, group_id): - return User.select().where(peewee.SQL("%s = ANY(groups)", group_id)) + return User.query.filter(group_id == db.func.any_(User.c.groups)) @classmethod def find_by_name(cls, org, group_names): - result = cls.select().where(cls.org == org, cls.name << group_names) + result = cls.query.filter(cls.org == org, cls.name.in_(group_names)) return list(result) def __unicode__(self): return unicode(self.id) - -class User(ModelTimestampsMixin, BaseModel, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="users") - name = peewee.CharField(max_length=320) - email = peewee.CharField(max_length=320) - password_hash = peewee.CharField(max_length=128, null=True) - groups = ArrayField(peewee.IntegerField, null=True) - api_key = peewee.CharField(max_length=40, unique=True) - - class Meta: - db_table = 'users' - - indexes = ( - (('org', 'email'), True), - ) +def create_group_hack(*a, **kw): + g = Group(*a, **kw) + db.session.add(g) + db.commit() + return g.id + +class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship(Organization, backref="users") + name = Column(db.String(320)) + email = Column(db.String(320)) + password_hash = Column(db.String(128), nullable=True) + group_ids = Column('groups', postgresql.ARRAY(db.Integer), nullable=True) + api_key = Column(db.String(40), + default=lambda: generate_token(40), + unique=True) + + __tablename__ = 'users' + __table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),) def __init__(self, *args, **kwargs): super(User, self).__init__(*args, **kwargs) @@ -405,12 +299,6 @@ def to_dict(self, with_api_key=False): return d - def pre_save(self, created): - super(User, self).pre_save(created) - - if not self.api_key: - self.api_key = generate_token(40) - @property def gravatar_url(self): email_md5 = hashlib.md5(self.email.lower()).hexdigest() @@ -457,30 +345,31 @@ def has_access(self, obj, access_type): return AccessPermission.exists(obj, access_type, grantee=self) -class ConfigurationField(peewee.TextField): - def db_value(self, value): +class Configuration(TypeDecorator): + + impl = db.Text + + def process_bind_param(self, value, dialect): return value.to_json() - def python_value(self, value): + def process_result_value(self, value, dialect): return ConfigurationContainer.from_json(value) -class DataSource(BelongsToOrgMixin, BaseModel): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="data_sources") - name = peewee.CharField() - type = peewee.CharField() - options = ConfigurationField() - queue_name = peewee.CharField(default="queries") - scheduled_queue_name = peewee.CharField(default="scheduled_queries") - created_at = DateTimeTZField(default=datetime.datetime.now) +class DataSource(BelongsToOrgMixin, db.Model): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship(Organization, backref="data_sources") - class Meta: - db_table = 'data_sources' + name = Column(db.String(255)) + type = Column(db.String(255)) + options = Column(Configuration) + queue_name = Column(db.String(255), default="queries") + scheduled_queue_name = Column(db.String(255), default="scheduled_queries") + created_at = Column(db.DateTime(True), default=db.func.now()) - indexes = ( - (('org', 'name'), True), - ) + __tablename__ = 'data_sources' + __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) def to_dict(self, all=False, with_permissions=False): d = { @@ -580,27 +469,30 @@ def groups(self): return dict(map(lambda g: (g.group_id, g.view_only), groups)) -class DataSourceGroup(BaseModel): - data_source = peewee.ForeignKeyField(DataSource) - group = peewee.ForeignKeyField(Group, related_name="data_sources") - view_only = peewee.BooleanField(default=False) +class DataSourceGroup(db.Model): + id = Column(db.Integer, primary_key=True) + data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) + data_source = db.relationship(DataSource) + group_id = Column(db.Integer, db.ForeignKey("groups.id")) + group = db.relationship(Group, backref="data_sources") + view_only = Column(db.Boolean, default=False) - class Meta: - db_table = "data_source_groups" + __tablename__ = "data_source_groups" -class QueryResult(BaseModel, BelongsToOrgMixin): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization) - data_source = peewee.ForeignKeyField(DataSource) - query_hash = peewee.CharField(max_length=32, index=True) - query = peewee.TextField() - data = peewee.TextField() - runtime = peewee.FloatField() - retrieved_at = DateTimeTZField() +class QueryResult(db.Model, BelongsToOrgMixin): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship(Organization) + data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) + data_source = db.relationship(DataSource) + query_hash = Column(db.String(32), index=True) + query = Column(db.Text) + data = Column(db.Text) + runtime = Column(postgresql.DOUBLE_PRECISION) + retrieved_at = Column(db.DateTime(True)) - class Meta: - db_table = 'query_results' + __tablename__ = 'query_results' def to_dict(self): return { @@ -689,25 +581,46 @@ def should_schedule_next(previous_iteration, now, schedule): return now > next_iteration -class Query(ChangeTrackingMixin, ModelTimestampsMixin, BaseVersionedModel, BelongsToOrgMixin): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="queries") - data_source = peewee.ForeignKeyField(DataSource, null=True) - latest_query_data = peewee.ForeignKeyField(QueryResult, null=True) - name = peewee.CharField(max_length=255) - description = peewee.CharField(max_length=4096, null=True) - query = peewee.TextField() - query_hash = peewee.CharField(max_length=32) - api_key = peewee.CharField(max_length=40) - user = peewee.ForeignKeyField(User) - last_modified_by = peewee.ForeignKeyField(User, null=True, related_name="modified_queries") - is_archived = peewee.BooleanField(default=False, index=True) - is_draft = peewee.BooleanField(default=True, index=True) - schedule = peewee.CharField(max_length=10, null=True) - options = JSONField(default={}) - - class Meta: - db_table = 'queries' +def generate_query_api_key(ctx): + return hashlib.sha1(u''.join(( + str(time.time()), ctx.current_parameters['query'], + str(ctx.current_parameters['user_id']), + ctx.current_parameters['name'])).encode('utf-8')).hexdigest() + +def gen_query_hash(ctx): + return utils.gen_query_hash(ctx.current_parameters['query']) + +class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): + id = Column(db.Integer, primary_key=True) + version = Column(db.Integer) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship(Organization, backref="queries") + data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id"), nullable=True) + data_source = db.relationship(DataSource) + latest_query_data_id = Column(db.Integer, db.ForeignKey("query_results.id"), nullable=True) + latest_query_data = db.relationship(QueryResult) + name = Column(db.String(255)) + description = Column(db.String(4096), nullable=True) + query = Column(db.Text) + query_hash = Column(db.String(32), + default=gen_query_hash, + onupdate=gen_query_hash) + api_key = Column(db.String(40), default=generate_query_api_key) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User, foreign_keys=[user_id]) + last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True, + onupdate=lambda ctx: ctx.current_parameters['user_id']) + last_modified_by = db.relationship(User, backref="modified_queries", + foreign_keys=[last_modified_by_id]) + is_archived = Column(db.Boolean, default=False, index=True) + is_draft = Column(db.Boolean, default=True, index=True) + schedule = Column(db.String(10), nullable=True) + options = Column(PseudoJSON, default={}) + + __tablename__ = 'queries' + __mapper_args__ = { + "version_id_col": version + } def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True): d = { @@ -843,6 +756,7 @@ def recent(cls, groups, user_id=None, limit=20): return query +<<<<<<< 196177021c6d0b0ccd50ecbb19ca6fef1ca7160a def fork(self, user): query = self forked_query = Query() @@ -923,18 +837,26 @@ def groups(self): def __unicode__(self): return unicode(self.id) +@listens_for(SignallingSession, 'before_flush') +def create_default_visualizations(session, ctx, *a): + for obj in session.new: + if isinstance(obj, Query): + session.add(Visualization(query=obj, name="Table", + description='', + type="TABLE", options="{}")) + + -class AccessPermission(BaseModel): - id = peewee.PrimaryKeyField() - object_type = peewee.CharField(index=True) - object_id = peewee.IntegerField(index=True) - object = GFKField('object_type', 'object_id') - access_type = peewee.CharField() - grantor = peewee.ForeignKeyField(User, related_name='grantor') - grantee = peewee.ForeignKeyField(User, related_name='grantee') +class AccessPermission(GFKBase, db.Model): + id = Column(db.Integer, primary_key=True) + # 'object' defined in GFKBase + access_type = Column(db.String(255)) + grantor_id = Column(db.Integer, db.ForeignKey("users.id")) + grantor = db.relationship(User, backref='grantor', foreign_keys=[grantor_id]) + grantee_id = Column(db.Integer, db.ForeignKey("users.id")) + grantee = db.relationship(User, backref='grantee', foreign_keys=[grantee_id]) - class Meta: - db_table = 'access_permissions' + __tablename__ = 'access_permissions' @classmethod def grant(cls, obj, access_type, grantee, grantor): @@ -982,18 +904,16 @@ def to_dict(self): return d -class Change(BaseModel): - id = peewee.PrimaryKeyField() - object_id = peewee.CharField(index=True) - object_type = peewee.CharField(index=True) - object_version = peewee.IntegerField(default=0) - object = GFKField('object_type', 'object_id') - user = peewee.ForeignKeyField(User, related_name='changes') - change = JSONField() - created_at = DateTimeTZField(default=datetime.datetime.now) +class Change(GFKBase, db.Model): + id = Column(db.Integer, primary_key=True) + # 'object' defined in GFKBase + object_version = Column(db.Integer, default=0) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User, backref='changes') + change = Column(PseudoJSON) + created_at = Column(db.DateTime(True), default=db.func.now()) - class Meta: - db_table = 'changes' + __tablename__ = 'changes' def to_dict(self, full=True): d = { @@ -1022,22 +942,23 @@ def last_change(cls, obj): return cls.select().where(cls.object_type==obj._meta.db_table, cls.object_id==obj.id).limit(1).first() -class Alert(ModelTimestampsMixin, BaseModel): +class Alert(TimestampMixin, db.Model): UNKNOWN_STATE = 'unknown' OK_STATE = 'ok' TRIGGERED_STATE = 'triggered' - id = peewee.PrimaryKeyField() - name = peewee.CharField() - query = peewee.ForeignKeyField(Query, related_name='alerts') - user = peewee.ForeignKeyField(User, related_name='alerts') - options = JSONField() - state = peewee.CharField(default=UNKNOWN_STATE) - last_triggered_at = DateTimeTZField(null=True) - rearm = peewee.IntegerField(null=True) + id = Column(db.Integer, primary_key=True) + name = Column(db.String(255)) + query_id = Column(db.Integer, db.ForeignKey("queries.id")) + query = db.relationship(Query, backref='alerts') + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User, backref='alerts') + options = Column(PseudoJSON) + state = Column(db.String(255), default=UNKNOWN_STATE) + last_triggered_at = Column(db.DateTime(True), nullable=True) + rearm = Column(db.Integer, nullable=True) - class Meta: - db_table = 'alerts' + __tablename__ = 'alerts' @classmethod def all(cls, groups): @@ -1099,19 +1020,33 @@ def groups(self): return self.query.groups -class Dashboard(ChangeTrackingMixin, ModelTimestampsMixin, BaseVersionedModel, BelongsToOrgMixin): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="dashboards") - slug = peewee.CharField(max_length=140, index=True) - name = peewee.CharField(max_length=100) - user = peewee.ForeignKeyField(User) - layout = peewee.TextField() - dashboard_filters_enabled = peewee.BooleanField(default=False) - is_archived = peewee.BooleanField(default=False, index=True) - is_draft = peewee.BooleanField(default=False, index=True) - - class Meta: - db_table = 'dashboards' +def generate_slug(ctx): + slug = utils.slugify(ctx.current_parameters['name']) + tries = 1 + while db.session.query(Dashboard).filter(Dashboard.slug == slug).first() is not None: + slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries) + tries += 1 + return slug + + +class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): + id = Column(db.Integer, primary_key=True) + version = Column(db.Integer) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) + org = db.relationship(Organization, backref="dashboards") + slug = Column(db.String(140), index=True, default=generate_slug) + name = Column(db.String(100)) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User) + layout = Column(db.Text) + dashboard_filters_enabled = Column(db.Boolean, default=False) + is_archived = Column(db.Boolean, default=False, index=True) + is_draft = Column(db.Boolean, default=True, index=True) + + __tablename__ = 'dashboards' + __mapper_args__ = { + "version_id_col": version + } def to_dict(self, with_widgets=False, user=None): layout = json.loads(self.layout) @@ -1225,31 +1160,21 @@ def tracked_save(self, changing_user, old_object=None, *args, **kwargs): new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self) return new_change - def save(self, *args, **kwargs): - if not self.slug: - self.slug = utils.slugify(self.name) - - tries = 1 - while self.select().where(Dashboard.slug == self.slug).first() is not None: - self.slug = utils.slugify(self.name) + "_{0}".format(tries) - tries += 1 - - super(Dashboard, self).save(*args, **kwargs) def __unicode__(self): return u"%s=%s" % (self.id, self.name) -class Visualization(ModelTimestampsMixin, BaseModel): - id = peewee.PrimaryKeyField() - type = peewee.CharField(max_length=100) - query = peewee.ForeignKeyField(Query, related_name='visualizations') - name = peewee.CharField(max_length=255) - description = peewee.CharField(max_length=4096, null=True) - options = peewee.TextField() +class Visualization(TimestampMixin, db.Model): + id = Column(db.Integer, primary_key=True) + type = Column(db.String(100)) + query_id = Column(db.Integer, db.ForeignKey("queries.id")) + query = db.relationship(Query, backref='visualizations') + name = Column(db.String(255)) + description = Column(db.String(4096), nullable=True) + options = Column(db.Text) - class Meta: - db_table = 'visualizations' + __tablename__ = 'visualizations' def to_dict(self, with_query=True): d = { @@ -1276,20 +1201,21 @@ def __unicode__(self): return u"%s %s" % (self.id, self.type) -class Widget(ModelTimestampsMixin, BaseModel): - id = peewee.PrimaryKeyField() - visualization = peewee.ForeignKeyField(Visualization, related_name='widgets', null=True) - text = peewee.TextField(null=True) - width = peewee.IntegerField() - options = peewee.TextField() - dashboard = peewee.ForeignKeyField(Dashboard, related_name='widgets', index=True) +class Widget(TimestampMixin, db.Model): + id = Column(db.Integer, primary_key=True) + visualization_id = Column(db.Integer, db.ForeignKey('visualizations.id'), nullable=True) + visualization = db.relationship(Visualization, backref='widgets') + text = Column(db.Text, nullable=True) + width = Column(db.Integer) + options = Column(db.Text) + dashboard_id = Column(db.Integer, db.ForeignKey("dashboards.id"), index=True) + dashboard = db.relationship(Dashboard, backref='widgets') # unused; kept for backward compatability: - type = peewee.CharField(max_length=100, null=True) - query_id = peewee.IntegerField(null=True) + type = Column(db.String(100), nullable=True) + query_id = Column(db.Integer, nullable=True) - class Meta: - db_table = 'widgets' + __tablename__ = 'widgets' def to_dict(self): d = { @@ -1323,17 +1249,19 @@ def delete_instance(self, *args, **kwargs): super(Widget, self).delete_instance(*args, **kwargs) -class Event(BaseModel): - org = peewee.ForeignKeyField(Organization, related_name="events") - user = peewee.ForeignKeyField(User, related_name="events", null=True) - action = peewee.CharField() - object_type = peewee.CharField() - object_id = peewee.CharField(null=True) - additional_properties = peewee.TextField(null=True) - created_at = DateTimeTZField(default=datetime.datetime.now) +class Event(db.Model): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) + org = db.relationship(Organization, backref="events") + user_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) + user = db.relationship(User, backref="events") + action = Column(db.String(255)) + object_type = Column(db.String(255)) + object_id = Column(db.String(255), nullable=True) + additional_properties = Column(db.Text, nullable=True) + created_at = Column(db.DateTime(True), default=db.func.now()) - class Meta: - db_table = 'events' + __tablename__ = 'events' def __unicode__(self): return u"%s,%s,%s,%s" % (self.user_id, self.action, self.object_type, self.object_id) @@ -1354,21 +1282,18 @@ def record(cls, event): return event +class ApiKey(TimestampMixin, GFKBase, db.Model): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) + org = db.relationship(Organization) + api_key = Column(db.String(255), index=True, default=lambda: generate_token(40)) + active = Column(db.Boolean, default=True) + #'object' provided by GFKBase + created_by_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) + created_by = db.relationship(User) -class ApiKey(ModelTimestampsMixin, BaseModel): - org = peewee.ForeignKeyField(Organization) - api_key = peewee.CharField(index=True, default=lambda: generate_token(40)) - active = peewee.BooleanField(default=True) - object_type = peewee.CharField() - object_id = peewee.IntegerField() - object = GFKField('object_type', 'object_id') - created_by = peewee.ForeignKeyField(User, null=True) - - class Meta: - db_table = 'api_keys' - indexes = ( - (('object_type', 'object_id'), False), - ) + __tablename__ = 'api_keys' + __table_args__ = (db.Index('api_keys_object_type_object_id', 'object_type', 'object_id'),) @classmethod def get_by_api_key(cls, api_key): @@ -1383,22 +1308,20 @@ def create_for_object(cls, object, user): return cls.create(org=user.org, object=object, created_by=user) -class NotificationDestination(BelongsToOrgMixin, BaseModel): - - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="notification_destinations") - user = peewee.ForeignKeyField(User, related_name="notification_destinations") - name = peewee.CharField() - type = peewee.CharField() - options = ConfigurationField() - created_at = DateTimeTZField(default=datetime.datetime.now) +class NotificationDestination(BelongsToOrgMixin, db.Model): - class Meta: - db_table = 'notification_destinations' - - indexes = ( - (('org', 'name'), True), - ) + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) + org = db.relationship(Organization, backref="notification_destinations") + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User, backref="notification_destinations") + name = Column(db.String(255)) + type = Column(db.String(255)) + options = Column(Configuration) + created_at = Column(db.DateTime(True), default=db.func.now()) + __tablename__ = 'notification_destinations' + __table_args__ = (db.Index('notification_destinations_org_id_name', 'org_id', + 'name', unique=True),) def to_dict(self, all=False): d = { @@ -1435,17 +1358,20 @@ def notify(self, alert, query, user, new_state, app, host): app, host, self.options) -class AlertSubscription(ModelTimestampsMixin, BaseModel): - user = peewee.ForeignKeyField(User) - destination = peewee.ForeignKeyField(NotificationDestination, null=True) - alert = peewee.ForeignKeyField(Alert, related_name="subscriptions") - - class Meta: - db_table = 'alert_subscriptions' +class AlertSubscription(TimestampMixin, db.Model): + id = Column(db.Integer, primary_key=True) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User) + destination_id = Column(db.Integer, + db.ForeignKey("notification_destinations.id"), + nullable=True) + destination = db.relationship(NotificationDestination) + alert_id = Column(db.Integer, db.ForeignKey("alerts.id")) + alert = db.relationship(Alert, backref="subscriptions") - indexes = ( - (('destination', 'alert'), True), - ) + __tablename__ = 'alert_subscriptions' + __table_args__ = (db.Index('alert_subscriptions_destination_id_alert_id', + 'destination_id', 'alert_id', unique=True),) def to_dict(self): d = { @@ -1476,16 +1402,16 @@ def notify(self, alert, query, user, new_state, app, host): return destination.notify(alert, query, user, new_state, app, host, options) -class QuerySnippet(ModelTimestampsMixin, BaseModel, BelongsToOrgMixin): - id = peewee.PrimaryKeyField() - org = peewee.ForeignKeyField(Organization, related_name="query_snippets") - trigger = peewee.CharField(unique=True) - description = peewee.TextField() - user = peewee.ForeignKeyField(User, related_name="query_snippets") - snippet = peewee.TextField() - - class Meta: - db_table = 'query_snippets' +class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) + org = db.relationship(Organization, backref="query_snippets") + trigger = Column(db.String(255), unique=True) + description = Column(db.Text) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship(User, backref="query_snippets") + snippet = Column(db.Text) + __tablename__ = 'query_snippets' @classmethod def all(cls, org): @@ -1504,26 +1430,26 @@ def to_dict(self): return d +_gfk_types = {'queries': Query, 'dashboards': Dashboard} all_models = (Organization, Group, DataSource, DataSourceGroup, User, QueryResult, Query, Alert, Dashboard, Visualization, Widget, Event, NotificationDestination, AlertSubscription, ApiKey, AccessPermission, Change) def init_db(): - default_org = Organization.create(name="Default", slug='default', settings={}) - admin_group = Group.create(name='admin', permissions=['admin', 'super_admin'], org=default_org, type=Group.BUILTIN_GROUP) - default_group = Group.create(name='default', permissions=Group.DEFAULT_PERMISSIONS, org=default_org, type=Group.BUILTIN_GROUP) + default_org = Organization(name="Default", slug='default', settings={}) + admin_group = Group(name='admin', permissions=['admin', 'super_admin'], org=default_org, type=Group.BUILTIN_GROUP) + default_group = Group(name='default', permissions=Group.DEFAULT_PERMISSIONS, org=default_org, type=Group.BUILTIN_GROUP) + db.session.add_all([default_org, admin_group, default_group]) + #XXX remove after fixing User.group_ids + db.session.commit() return default_org, admin_group, default_group def create_db(create_tables, drop_tables): - db.connect_db() - - for model in all_models: - if drop_tables and model.table_exists(): - model.drop_table(cascade=True) - - if create_tables and not model.table_exists(): - model.create_table() + if drop_tables: + db.session.rollback() + db.drop_all() - db.close_db(None) + if create_tables: + db.create_all() diff --git a/redash/settings.py b/redash/settings.py index a42ff0f3dd..f03add2ac5 100644 --- a/redash/settings.py +++ b/redash/settings.py @@ -65,7 +65,8 @@ def all_settings(): STATSD_USE_TAGS = parse_boolean(os.environ.get('REDASH_STATSD_USE_TAGS', "false")) # Connection settings for re:dash's own database (where we store the queries, results, etc) -DATABASE_CONFIG = parse_db_url(os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql://postgres"))) +SQLALCHEMY_DATABASE_URI = os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql://postgres")) +SQLALCHEMY_TRACK_MODIFICATIONS = False # Celery related settings CELERY_BROKER = os.environ.get("REDASH_CELERY_BROKER", REDIS_URL) diff --git a/requirements.txt b/requirements.txt index ac97301855..d509964b80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ Flask-Admin==1.1.0 Flask-RESTful==0.3.5 Flask-Login==0.3.2 Flask-OAuthLib==0.9.2 +Flask-SQLAlchemy==2.1 flask-mail==0.9.1 flask-sslify==0.1.5 passlib==1.6.2 @@ -13,13 +14,13 @@ Werkzeug==0.11.3 aniso8601==1.1.0 blinker==1.3 itsdangerous==0.24 -peewee==2.6.1 psycopg2==2.5.2 python-dateutil==2.4.2 pytz==2016.7 redis==2.10.5 requests==2.11.1 six==1.10.0 +SQLAlchemy==1.1.4 sqlparse==0.1.8 wsgiref==0.1.2 honcho==0.5.0 @@ -29,7 +30,6 @@ celery==3.1.23 jsonschema==2.4.0 click==6.6 RestrictedPython==3.6.0 -wtf-peewee==0.2.3 pysaml2==2.4.0 pycrypto==2.6.1 funcy==1.7.1 diff --git a/tests/__init__.py b/tests/__init__.py index 3e76c7ab41..7bd5dac677 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,7 +11,7 @@ import logging from unittest import TestCase import datetime -from redash import settings +from redash import create_app, settings from factories import Factory settings.DATABASE_CONFIG = { @@ -25,17 +25,20 @@ logging.disable("INFO") logging.getLogger("metrics").setLevel("ERROR") -logging.getLogger('peewee').setLevel(logging.INFO) class BaseTestCase(TestCase): def setUp(self): + self.app = create_app() + self.app_ctx = self.app.app_context() + self.app_ctx.push() redash.models.create_db(True, True) self.factory = Factory() + def tearDown(self): - redash.models.db.close_db(None) redash.models.create_db(False, True) + self.app_ctx.pop() redis_connection.flushdb() def make_request(self, method, path, org=None, user=None, data=None, is_json=True): diff --git a/tests/factories.py b/tests/factories.py index fe6c5a1b9e..0a9ec2841d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,4 +1,5 @@ import redash.models +from redash.models import db from redash.utils import gen_query_hash, utcnow from redash.utils.configuration import ConfigurationContainer from redash.permissions import ACCESS_TYPE_MODIFY @@ -19,14 +20,11 @@ def _get_kwargs(self, override_kwargs): return kwargs - def instance(self, **override_kwargs): - kwargs = self._get_kwargs(override_kwargs) - - return self.model(**kwargs) - def create(self, **override_kwargs): kwargs = self._get_kwargs(override_kwargs) - return self.model.create(**kwargs) + obj = self.model(**kwargs) + db.session.add(obj) + return obj class Sequence(object): @@ -42,8 +40,8 @@ def __call__(self): user_factory = ModelFactory(redash.models.User, name='John Doe', email=Sequence('test{}@example.com'), - groups=[2], - org=1) + group_ids=[2], + org_id=1) org_factory = ModelFactory(redash.models.Organization, name=Sequence("Org {}"), @@ -151,31 +149,33 @@ def user(self): def data_source(self): if self._data_source is None: self._data_source = data_source_factory.create(org=self.org) - redash.models.DataSourceGroup.create(group=self.default_group, data_source=self._data_source) + db.session.add(redash.models.DataSourceGroup( + group=self.default_group, + data_source=self._data_source)) return self._data_source def _init_org(self): if self._org is None: self._org, self._admin_group, self._default_group = redash.models.init_db() - self.org.update_instance(domain='org0.example.org') + self.org.domain = 'org0.example.org' def create_org(self, **kwargs): org = org_factory.create(**kwargs) - self.create_group(org=org, type=redash.models.Group.BUILTIN_GROUP, name="default") - self.create_group(org=org, type=redash.models.Group.BUILTIN_GROUP, name="admin", permissions=["admin"]) + self.create_group(org=org, type=redash.models.Group.BUILTIN_GROUP, name="admin", + permissions=["admin"]) return org def create_user(self, **kwargs): args = { 'org': self.org, - 'groups': [self.default_group.id] + 'group_ids': [self.default_group.id] } if 'org' in kwargs: - args['groups'] = [kwargs['org'].default_group.id] + args['group_ids'] = [kwargs['org'].default_group.id] args.update(kwargs) return user_factory.create(**args) @@ -183,11 +183,11 @@ def create_user(self, **kwargs): def create_admin(self, **kwargs): args = { 'org': self.org, - 'groups': [self.admin_group.id, self.default_group.id] + 'group_ids': [self.admin_group.id, self.default_group.id] } if 'org' in kwargs: - args['groups'] = [kwargs['org'].default_group.id, kwargs['org'].admin_group.id] + args['group_ids'] = [kwargs['org'].default_group.id, kwargs['org'].admin_group.id] args.update(kwargs) return user_factory.create(**args) @@ -200,7 +200,19 @@ def create_group(self, **kwargs): args.update(kwargs) - return redash.models.Group.create(**args) + g = redash.models.Group(**args) + return g + + def create_group_hack(self, **kwargs): + args = { + 'name': 'Group', + 'org': self.org + } + + args.update(kwargs) + + g_id = redash.models.create_group_hack(**args) + return g_id def create_alert(self, **kwargs): args = { @@ -231,12 +243,13 @@ def create_data_source(self, **kwargs): data_source = data_source_factory.create(**args) - if 'group' in kwargs: + if 'group_id' in kwargs: view_only = kwargs.pop('view_only', False) - redash.models.DataSourceGroup.create(group=kwargs['group'], - data_source=data_source, - view_only=view_only) + db.session.add(redash.models.DataSourceGroup( + group_id=kwargs['group_id'], + data_source=data_source, + view_only=view_only)) return data_source diff --git a/tests/models/test_base_versioned_model.py b/tests/models/test_base_versioned_model.py deleted file mode 100644 index 9aa2a3255b..0000000000 --- a/tests/models/test_base_versioned_model.py +++ /dev/null @@ -1,60 +0,0 @@ -import peewee - -from mock import patch -from tests import BaseTestCase -from redash.models import ChangeTrackingMixin, BaseVersionedModel, ConflictDetectedError - - -class TestModel(BaseVersionedModel): - value = peewee.IntegerField() - - class Meta: - db_table = 'test_mode' - - -class TestModelTestCase(BaseTestCase): - def setUp(self): - super(TestModelTestCase, self).setUp() - TestModel.create_table() - - def tearDown(self): - super(TestModelTestCase, self).tearDown() - TestModel.drop_table() - - -class TestBaseVersionedModel(TestModelTestCase): - def test_creates_first_instance_with_version_0(self): - t = TestModel(value=123) - t.save() - - self.assertIsNotNone(t.id) - self.assertEqual(t.version, 1) - self.assertEqual(t.value, 123) - - def test_fails_when_there_is_version_conflict(self): - t = TestModel(value=123) - t.save() - - t1 = TestModel.get(TestModel.id==t.id) - t2 = TestModel.get(TestModel.id==t.id) - - t1.value = 124 - t1.save() - - self.assertRaises(ConflictDetectedError, lambda: t2.save()) - - def test_calls_save_hooks(self): - t = TestModel(value=123) - - with patch(__name__ + '.TestModel.pre_save') as pre_save_mock, patch(__name__ + '.TestModel.post_save') as post_save_mock: - t.save() - - pre_save_mock.assert_called_once_with(True) - post_save_mock.assert_called_once_with(True) - - t.value = 124 - with patch(__name__ + '.TestModel.pre_save') as pre_save_mock, patch(__name__ + '.TestModel.post_save') as post_save_mock: - t.save() - - pre_save_mock.assert_called_once_with(False) - post_save_mock.assert_called_once_with(False) diff --git a/tests/test_models.py b/tests/test_models.py index 18eeaeacfe..e1eab88614 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,18 +6,22 @@ from dateutil.parser import parse as date_parse from tests import BaseTestCase from redash import models +from redash.models import db from redash.utils import gen_query_hash, utcnow class DashboardTest(BaseTestCase): def test_appends_suffix_to_slug_when_duplicate(self): d1 = self.factory.create_dashboard() + db.session.flush() self.assertEquals(d1.slug, 'test') d2 = self.factory.create_dashboard(user=d1.user) + db.session.flush() self.assertNotEquals(d1.slug, d2.slug) d3 = self.factory.create_dashboard(user=d1.user) + db.session.flush() self.assertNotEquals(d1.slug, d3.slug) self.assertNotEquals(d2.slug, d3.slug) @@ -25,19 +29,18 @@ def test_appends_suffix_to_slug_when_duplicate(self): class QueryTest(BaseTestCase): def test_changing_query_text_changes_hash(self): q = self.factory.create_query() - old_hash = q.query_hash - q.update_instance(query="SELECT 2;") - - q = models.Query.get_by_id(q.id) + q.query = "SELECT 2;" + #q = db.session.query(models.Query).get(q.id) + db.session.flush() self.assertNotEquals(old_hash, q.query_hash) def test_search_finds_in_name(self): q1 = self.factory.create_query(name=u"Testing seåřċħ") q2 = self.factory.create_query(name=u"Testing seåřċħing") q3 = self.factory.create_query(name=u"Testing seå řċħ") - + db.session.flush() queries = models.Query.search(u"seåřċħ", [self.factory.default_group]) self.assertIn(q1, queries) @@ -599,8 +602,8 @@ def _set_up_dashboard_test(d): d.g2 = d.factory.create_group(name='Second') d.ds1 = d.factory.create_data_source() d.ds2 = d.factory.create_data_source() - d.u1 = d.factory.create_user(groups=[d.g1.id]) - d.u2 = d.factory.create_user(groups=[d.g2.id]) + d.u1 = d.factory.create_user(group_ids=[d.g1.id]) + d.u2 = d.factory.create_user(group_ids=[d.g2.id]) models.DataSourceGroup.create(group=d.g1, data_source=d.ds1, permissions=['create', 'view']) models.DataSourceGroup.create(group=d.g2, data_source=d.ds2, permissions=['create', 'view']) d.q1 = d.factory.create_query(data_source=d.ds1) From d2aef544c3bbb4a125e0b5b0490f4efc57857be5 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Mon, 21 Nov 2016 12:35:07 -0600 Subject: [PATCH 02/80] properly handle view_only permission in groups API --- redash/handlers/groups.py | 6 +++--- redash/models.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/redash/handlers/groups.py b/redash/handlers/groups.py index 6ea845acc6..61025fd5f7 100644 --- a/redash/handlers/groups.py +++ b/redash/handlers/groups.py @@ -128,7 +128,7 @@ def post(self, group_id): 'member_id': data_source.id }) - return data_source.to_dict(with_permissions=True) + return data_source.to_dict(with_permissions_for=group) @require_admin def get(self, group_id): @@ -139,7 +139,7 @@ def get(self, group_id): .join(models.DataSourceGroup)\ .where(models.DataSourceGroup.group == group) - return [ds.to_dict(with_permissions=True) for ds in data_sources] + return [ds.to_dict(with_permissions_for=group) for ds in data_sources] class GroupDataSourceResource(BaseResource): @@ -160,7 +160,7 @@ def post(self, group_id, data_source_id): 'view_only': view_only }) - return data_source.to_dict(with_permissions=True) + return data_source.to_dict(with_permissions_for=group) @require_admin def delete(self, group_id, data_source_id): diff --git a/redash/models.py b/redash/models.py index 114891f393..f5d61e3a7a 100644 --- a/redash/models.py +++ b/redash/models.py @@ -371,7 +371,7 @@ class DataSource(BelongsToOrgMixin, db.Model): __tablename__ = 'data_sources' __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) - def to_dict(self, all=False, with_permissions=False): + def to_dict(self, all=False, with_permissions_for=None): d = { 'id': self.id, 'name': self.name, @@ -389,8 +389,10 @@ def to_dict(self, all=False, with_permissions=False): d['scheduled_queue_name'] = self.scheduled_queue_name d['groups'] = self.groups - if with_permissions: - d['view_only'] = self.data_source_groups.view_only + if with_permissions_for is not None: + d['view_only'] = db.session.query(DataSourceGroup.view_only).filter( + DataSourceGroup.group == with_permissions_for, + DataSourceGroup.data_source == self).get() return d @@ -438,17 +440,20 @@ def resume(self): redis_connection.delete(self._pause_key()) def add_group(self, group, view_only=False): - dsg = DataSourceGroup.create(group=group, data_source=self, view_only=view_only) - setattr(self, 'data_source_groups', dsg) + dsg = DataSourceGroup(group=group, data_source=self, view_only=view_only) + db.session.add(dsg) def remove_group(self, group): - DataSourceGroup.delete().where(DataSourceGroup.group==group, DataSourceGroup.data_source==self).execute() + db.session.query(DataSourceGroup).filter( + DataSourceGroup.group == group, + DataSourceGroup.data_source == self).delete() def update_group_permission(self, group, view_only): - dsg = DataSourceGroup.get(DataSourceGroup.group==group, DataSourceGroup.data_source==self) + dsg = db.session.query(DataSourceGroup).filter( + DataSourceGroup.group == group, + DataSourceGroup.data_source == self) dsg.view_only = view_only - dsg.save() - setattr(self, 'data_source_groups', dsg) + db.session.add(dsg) @property def query_runner(self): From ea166665d31c9cbd24a4627af47700e6369d9a5f Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 23 Nov 2016 12:35:18 -0600 Subject: [PATCH 03/80] test_models passes --- redash/models.py | 309 +++++++++++++----------- tests/factories.py | 22 +- tests/tasks/test_refresh_queries.py | 2 +- tests/test_models.py | 355 +++++++++++++++------------- 4 files changed, 367 insertions(+), 321 deletions(-) diff --git a/redash/models.py b/redash/models.py index f5d61e3a7a..7d58cb13d5 100644 --- a/redash/models.py +++ b/redash/models.py @@ -108,17 +108,18 @@ def __setattr__(self, key, value): super(ChangeTrackingMixin, self).__setattr__(key, value) - def record_changes(self, session, changed_by): + def record_changes(self, changed_by): changes = {} for k, v in self._clean_values.iteritems(): if k not in self.skipped_fields: changes[k] = {'previous': v, 'current': getattr(self, k)} - session.add(Change(object_type=self.__class__.__tablename__, - object_id=self.id, + db.session.flush() + db.session.add(Change(object_type=self.__class__.__tablename__, + object=self, object_version=self.version, - user_id=changed_by.id, + user=changed_by, change=changes)) - session.add(self) + class ConflictDetectedError(Exception): @@ -127,7 +128,7 @@ class ConflictDetectedError(Exception): class BelongsToOrgMixin(object): @classmethod def get_by_id_and_org(cls, object_id, org): - return cls.query.filter(cls.id == object_id, cls.org == org).first() + return cls.query.filter(cls.id == object_id, cls.org == org).one_or_none() class PermissionsCheckMixin(object): @@ -180,6 +181,7 @@ class Organization(TimestampMixin, db.Model): name = Column(db.String(255)) slug = Column(db.String(255), unique=True) settings = Column(PseudoJSON) + groups = db.relationship("Group", lazy="dynamic") __tablename__ = 'organizations' @@ -220,7 +222,7 @@ class Group(db.Model, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey('organizations.id')) - org = db.relationship(Organization, backref="groups") + org = db.relationship(Organization, back_populates="groups") type = Column(db.String(255), default=REGULAR_GROUP) name = Column(db.String(100)) permissions = Column(postgresql.ARRAY(db.String(255)), @@ -267,6 +269,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh name = Column(db.String(320)) email = Column(db.String(320)) password_hash = Column(db.String(128), nullable=True) + #XXX replace with association table group_ids = Column('groups', postgresql.ARRAY(db.Integer), nullable=True) api_key = Column(db.String(40), default=lambda: generate_token(40), @@ -338,8 +341,8 @@ def verify_password(self, password): def update_group_assignments(self, group_names): groups = Group.find_by_name(self.org, group_names) groups.append(self.org.default_group) - self.groups = map(lambda g: g.id, groups) - self.save() + self.group_ids = [g.id for g in groups] + db.session.add(self) def has_access(self, obj, access_type): return AccessPermission.exists(obj, access_type, grantee=self) @@ -368,6 +371,7 @@ class DataSource(BelongsToOrgMixin, db.Model): scheduled_queue_name = Column(db.String(255), default="scheduled_queries") created_at = Column(db.DateTime(True), default=db.func.now()) + data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source") __tablename__ = 'data_sources' __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) @@ -468,16 +472,19 @@ def all(cls, org, groups=None): return data_sources + #XXX examine call sites to see if a regular SQLA collection would work better @property def groups(self): - groups = DataSourceGroup.select().where(DataSourceGroup.data_source==self) + groups = db.session.query(DataSourceGroup).filter( + DataSourceGroup.data_source == self) return dict(map(lambda g: (g.group_id, g.view_only), groups)) class DataSourceGroup(db.Model): + #XXX drop id, use datasource/group as PK id = Column(db.Integer, primary_key=True) data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) - data_source = db.relationship(DataSource) + data_source = db.relationship(DataSource, back_populates="data_source_groups") group_id = Column(db.Integer, db.ForeignKey("groups.id")) group = db.relationship(Group, backref="data_sources") view_only = Column(db.Boolean, default=False) @@ -514,8 +521,9 @@ def to_dict(self): def unused(cls, days=7): age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) - unused_results = cls.select().where(Query.id == None, cls.retrieved_at < age_threshold)\ - .join(Query, join_type=peewee.JOIN_LEFT_OUTER) + unused_results = (db.session.query(QueryResult).filter( + Query.id == None, QueryResult.retrieved_at < age_threshold) + .outerjoin(Query)) return unused_results @@ -524,35 +532,41 @@ def get_latest(cls, data_source, query, max_age=0): query_hash = utils.gen_query_hash(query) if max_age == -1: - query = cls.select().where(cls.query_hash == query_hash, - cls.data_source == data_source).order_by(cls.retrieved_at.desc()) + q = db.session.query(QueryResult).filter( + cls.query_hash == query_hash, + cls.data_source == data_source).order_by( + QueryResult.retrieved_at.desc()) else: - query = cls.select().where(cls.query_hash == query_hash, cls.data_source == data_source, - peewee.SQL("retrieved_at at time zone 'utc' + interval '%s second' >= now() at time zone 'utc'", - max_age)).order_by(cls.retrieved_at.desc()) + q = db.session.query(QueryResult).filter( + QueryResult.query_hash == query_hash, + QueryResult.data_source == data_source, + db.func.timezone('utc', QueryResult.retrieved_at) + + datetime.timedelta(seconds=max_age) >= + db.func.timezone('utc', db.func.now()) + ).order_by(QueryResult.retrieved_at.desc()) - return query.first() + return q.first() @classmethod - def store_result(cls, org_id, data_source_id, query_hash, query, data, run_time, retrieved_at): - query_result = cls.create(org=org_id, - query_hash=query_hash, - query=query, - runtime=run_time, - data_source=data_source_id, - retrieved_at=retrieved_at, - data=data) - + def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): + query_result = cls(org=org, + query_hash=query_hash, + query=query, + runtime=run_time, + data_source=data_source, + retrieved_at=retrieved_at, + data=data) + db.session.add(query_result) logging.info("Inserted query (%s) data; id=%s", query_hash, query_result.id) - sql = "UPDATE queries SET latest_query_data_id = %s WHERE query_hash = %s AND data_source_id = %s RETURNING id" - query_ids = [row[0] for row in db.database.execute_sql(sql, params=(query_result.id, query_hash, data_source_id))] - - # TODO: when peewee with update & returning support is released, we can get back to using this code: - # updated_count = Query.update(latest_query_data=query_result).\ - # where(Query.query_hash==query_hash, Query.data_source==data_source_id).\ - # execute() - + # TODO: Investigate how big an impact this select-before-update makes. + queries = db.session.query(Query).filter( + Query.query_hash == query_hash, + Query.data_source == data_source) + for q in queries: + q.latest_query_data = query_result + db.session.add(q) + query_ids = [q.id for q in queries] logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash) return query_result, query_ids @@ -592,8 +606,6 @@ def generate_query_api_key(ctx): str(ctx.current_parameters['user_id']), ctx.current_parameters['name'])).encode('utf-8')).hexdigest() -def gen_query_hash(ctx): - return utils.gen_query_hash(ctx.current_parameters['query']) class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) @@ -607,14 +619,11 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) query = Column(db.Text) - query_hash = Column(db.String(32), - default=gen_query_hash, - onupdate=gen_query_hash) + query_hash = Column(db.String(32)) api_key = Column(db.String(40), default=generate_query_api_key) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, foreign_keys=[user_id]) - last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True, - onupdate=lambda ctx: ctx.current_parameters['user_id']) + last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True) last_modified_by = db.relationship(User, backref="modified_queries", foreign_keys=[last_modified_by_id]) is_archived = Column(db.Boolean, default=False, index=True) @@ -667,45 +676,48 @@ def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, w return d def archive(self, user=None): + db.session.add(self) self.is_archived = True self.schedule = None for vis in self.visualizations: for w in vis.widgets: - w.delete_instance() + db.session.delete(w) - for alert in self.alerts: - alert.delete_instance(recursive=True) + for a in self.alerts: + db.session.delete(a) - self.save(changed_by=user) + if user: + self.record_changes(user) @classmethod def all_queries(cls, groups, drafts=False): - q = Query.select(Query, User, QueryResult.retrieved_at, QueryResult.runtime)\ - .join(QueryResult, join_type=peewee.JOIN_LEFT_OUTER)\ - .switch(Query).join(User)\ - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source))\ - .where(Query.is_archived==False)\ - .where(DataSourceGroup.group << groups)\ - .group_by(Query.id, User.id, QueryResult.id, QueryResult.retrieved_at, QueryResult.runtime)\ - .order_by(cls.created_at.desc()) + q = (db.session.query(Query) + .outerjoin(QueryResult) + .join(User, Query.user_id == User.id) + .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter(Query.is_archived == False) + .filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\ + .group_by(Query.id, User.id, QueryResult.id, QueryResult.retrieved_at, QueryResult.runtime) + .order_by(Query.created_at.desc())) if drafts: - q = q.where(Query.is_draft == True) + q = q.filter(Query.is_draft == True) else: - q = q.where(Query.is_draft == False) + q = q.filter(Query.is_draft == False) + return q @classmethod def by_user(cls, user, drafts): - return cls.all_queries(user.groups, drafts).where(Query.user==user) + return cls.all_queries(user.groups, drafts).filter(Query.user == user) @classmethod def outdated_queries(cls): - queries = cls.select(cls, QueryResult.retrieved_at, DataSource)\ - .join(QueryResult)\ - .switch(Query).join(DataSource)\ - .where(cls.schedule != None) + queries = (db.session.query(Query) + .join(QueryResult) + .join(DataSource) + .filter(Query.schedule != None)) now = utils.utcnow() outdated_queries = {} @@ -719,49 +731,47 @@ def outdated_queries(cls): @classmethod def search(cls, term, groups): # TODO: This is very naive implementation of search, to be replaced with PostgreSQL full-text-search solution. - - where = (cls.name**u"%{}%".format(term)) | (cls.description**u"%{}%".format(term)) + where = (Query.name.like(u"%{}%".format(term)) | + Query.description.like(u"%{}%".format(term))) if term.isdigit(): - where |= cls.id == term + where |= Query.id == term - where &= cls.is_archived == False - - query_ids = cls.select(peewee.fn.Distinct(cls.id))\ - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source)) \ - .where(where) \ - .where(DataSourceGroup.group << groups) - - return cls.select(Query, User).join(User).where(cls.id << query_ids) + where &= Query.is_archived == False + where &= DataSourceGroup.group_id.in_([g.id for g in groups]) + query_ids = ( + db.session.query(Query.id).join( + DataSourceGroup, + Query.data_source_id == DataSourceGroup.data_source_id) + .filter(where)).distinct() + return db.session.query(Query).join(User, Query.user_id == User.id).filter( + Query.id.in_(query_ids)) @classmethod def recent(cls, groups, user_id=None, limit=20): - query = ( - cls.select(Query, User) - .where(Event.created_at > peewee.SQL("current_date - 7")) - .join(Event, on=(Query.id == Event.object_id.cast('integer'))) - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source)) - .switch(Query).join(User) - .where(Event.action << ('edit', 'execute', 'edit_name', - 'edit_description', 'toggle_published', - 'view_source')) - .where(~(Event.object_id >> None)) - .where(Event.object_type == 'query') - .where(DataSourceGroup.group << groups) - .where(cls.is_archived == False) - .where(cls.is_draft == False) - .group_by(Event.object_id, Query.id, User.id) - .order_by(peewee.SQL("count(0) desc"))) + query = (db.session.query(Query).join(User, Query.user_id == User.id) + .filter(Event.created_at > (db.func.current_date() - 7)) + .join(Event, Query.id == Event.object_id.cast(db.Integer)) + .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Event.action.in_(['edit', 'execute', 'edit_name', + 'edit_description', 'view_source']), + Event.object_id != None, + Event.object_type == 'query', + DataSourceGroup.group_id.in_([g.id for g in groups]), + Query.is_draft == False, + Query.is_archived == False) + .group_by(Event.object_id, Query.id, User.id) + .order_by(db.desc(db.func.count(0)))) if user_id: - query = query.where(Event.user == user_id) + query = query.filter(Event.user_id == user_id) query = query.limit(limit) return query -<<<<<<< 196177021c6d0b0ccd50ecbb19ca6fef1ca7160a def fork(self, user): query = self forked_query = Query() @@ -842,14 +852,26 @@ def groups(self): def __unicode__(self): return unicode(self.id) +@listens_for(Query.query, 'set') +def gen_query_hash(target, val, oldval, initiator): + target.query_hash = utils.gen_query_hash(val) + +@listens_for(Query.user_id, 'set') +def query_last_modified_by(target, val, oldval, initiator): + target.last_modified_by_id = val + @listens_for(SignallingSession, 'before_flush') -def create_default_visualizations(session, ctx, *a): +def create_defaults(session, ctx, *a): for obj in session.new: if isinstance(obj, Query): session.add(Visualization(query=obj, name="Table", description='', type="TABLE", options="{}")) +@listens_for(ChangeTrackingMixin, 'init') +def create_first_change(obj, args, kwargs): + obj.record_changes(obj.user) + class AccessPermission(GFKBase, db.Model): @@ -960,6 +982,7 @@ class Alert(TimestampMixin, db.Model): user = db.relationship(User, backref='alerts') options = Column(PseudoJSON) state = Column(db.String(255), default=UNKNOWN_STATE) + subscriptions = db.relationship("AlertSubscription", cascade="delete") last_triggered_at = Column(db.DateTime(True), nullable=True) rearm = Column(db.Integer, nullable=True) @@ -1043,6 +1066,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model name = Column(db.String(100)) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User) + # XXX replace with association table layout = Column(db.Text) dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) @@ -1108,47 +1132,48 @@ def to_dict(self, with_widgets=False, user=None): } @classmethod - def all(cls, org, groups, user_id): - query = (cls.select() - .join(Widget, peewee.JOIN_LEFT_OUTER, - on=(Dashboard.id == Widget.dashboard)) - .join(Visualization, peewee.JOIN_LEFT_OUTER, - on=(Widget.visualization == Visualization.id)) - .join(Query, peewee.JOIN_LEFT_OUTER, - on=(Visualization.query == Query.id)) - .join(DataSourceGroup, peewee.JOIN_LEFT_OUTER, - on=(Query.data_source == DataSourceGroup.data_source)) - .where(Dashboard.is_archived == False) - .where((DataSourceGroup.group << groups & (Dashboard.is_draft != True)) | - (Dashboard.user == user_id) | - (~(Widget.dashboard >> None) & (Widget.visualization >> None))) - .where(Dashboard.org == org) + def all(cls, org, group_ids, user_id): + query = ( + db.session.query(Dashboard) + .outerjoin(Widget) + .outerjoin(Visualization) + .outerjoin(Query) + .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Dashboard.is_archived == False, + (DataSourceGroup.group_id.in_(group_ids) | + (Dashboard.user_id == user_id) | + ((Widget.dashboard != None) & (Widget.visualization == None))), + Dashboard.org == org) .group_by(Dashboard.id)) return query @classmethod - def recent(cls, org, groups, user_id, for_user=False, limit=20): - query = cls.select().where(Event.created_at > peewee.SQL("current_date - 7")). \ - join(Event, peewee.JOIN_LEFT_OUTER, on=(Dashboard.id == Event.object_id.cast('integer'))). \ - join(Widget, peewee.JOIN_LEFT_OUTER, on=(Dashboard.id == Widget.dashboard)). \ - join(Visualization, peewee.JOIN_LEFT_OUTER, on=(Widget.visualization == Visualization.id)). \ - join(Query, peewee.JOIN_LEFT_OUTER, on=(Visualization.query == Query.id)). \ - join(DataSourceGroup, peewee.JOIN_LEFT_OUTER, on=(Query.data_source == DataSourceGroup.data_source)). \ - where(Event.action << ('edit', 'view')). \ - where(~(Event.object_id >> None)). \ - where(Event.object_type == 'dashboard'). \ - where(Dashboard.is_archived == False). \ - where(Dashboard.is_draft == False). \ - where(Dashboard.org == org). \ - where((DataSourceGroup.group << groups) | - (Dashboard.user == user_id) | - (~(Widget.dashboard >> None) & (Widget.visualization >> None))). \ - group_by(Event.object_id, Dashboard.id). \ - order_by(peewee.SQL("count(0) desc")) + def recent(cls, org, group_ids, user_id, for_user=False, limit=20): + query = (db.session.query(Dashboard) + .outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer)) + .outerjoin(Widget) + .outerjoin(Visualization) + .outerjoin(Query) + .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Event.created_at > (db.func.current_date() - 7), + Event.action.in_(['edit', 'view']), + Event.object_id != None, + Event.object_type == 'dashboard', + Dashboard.org == org, + Dashboard.is_draft == False, + Dashboard.is_archived == False, + DataSourceGroup.group_id.in_(group_ids) | + (Dashboard.user_id == user_id) | + ((Widget.dashboard != None) & (Widget.visualization == None))) + .group_by(Event.object_id, Dashboard.id) + .order_by(db.desc(db.func.count(0)))) + if for_user: - query = query.where(Event.user == user_id) + query = query.filter(Event.user_id == user_id) query = query.limit(limit) @@ -1214,7 +1239,7 @@ class Widget(TimestampMixin, db.Model): width = Column(db.Integer) options = Column(db.Text) dashboard_id = Column(db.Integer, db.ForeignKey("dashboards.id"), index=True) - dashboard = db.relationship(Dashboard, backref='widgets') + dashboard = db.relationship(Dashboard) # unused; kept for backward compatability: type = Column(db.String(100), nullable=True) @@ -1245,13 +1270,14 @@ def __unicode__(self): def get_by_id_and_org(cls, widget_id, org): return cls.select(cls, Dashboard).join(Dashboard).where(cls.id == widget_id, Dashboard.org == org).get() - def delete_instance(self, *args, **kwargs): - layout = json.loads(self.dashboard.layout) - layout = map(lambda row: filter(lambda w: w != self.id, row), layout) - layout = filter(lambda row: len(row) > 0, layout) - self.dashboard.layout = json.dumps(layout) - self.dashboard.save() - super(Widget, self).delete_instance(*args, **kwargs) +#XXX produces SQLA warning, replace with association table +@listens_for(Widget, 'before_delete') +def widget_delete(mapper, connection, self): + layout = json.loads(self.dashboard.layout) + layout = map(lambda row: filter(lambda w: w != self.id, row), layout) + layout = filter(lambda row: len(row) > 0, layout) + self.dashboard.layout = json.dumps(layout) + db.session.add(self.dashboard) class Event(db.Model): @@ -1261,6 +1287,7 @@ class Event(db.Model): user_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) user = db.relationship(User, backref="events") action = Column(db.String(255)) + # XXX replace with association table object_type = Column(db.String(255)) object_id = Column(db.String(255), nullable=True) additional_properties = Column(db.Text, nullable=True) @@ -1273,8 +1300,8 @@ def __unicode__(self): @classmethod def record(cls, event): - org = event.pop('org_id') - user = event.pop('user_id', None) + org_id = event.pop('org_id') + user_id = event.pop('user_id', None) action = event.pop('action') object_type = event.pop('object_type') object_id = event.pop('object_id', None) @@ -1282,9 +1309,11 @@ def record(cls, event): created_at = datetime.datetime.utcfromtimestamp(event.pop('timestamp')) additional_properties = json.dumps(event) - event = cls.create(org=org, user=user, action=action, object_type=object_type, object_id=object_id, - additional_properties=additional_properties, created_at=created_at) - + event = cls(org_id=org_id, user_id=user_id, action=action, + object_type=object_type, object_id=object_id, + additional_properties=additional_properties, + created_at=created_at) + db.session.add(event) return event class ApiKey(TimestampMixin, GFKBase, db.Model): @@ -1372,7 +1401,7 @@ class AlertSubscription(TimestampMixin, db.Model): nullable=True) destination = db.relationship(NotificationDestination) alert_id = Column(db.Integer, db.ForeignKey("alerts.id")) - alert = db.relationship(Alert, backref="subscriptions") + alert = db.relationship(Alert, back_populates="subscriptions") __tablename__ = 'alert_subscriptions' __table_args__ = (db.Index('alert_subscriptions_destination_id_alert_id', diff --git a/tests/factories.py b/tests/factories.py index 0a9ec2841d..df242de1fc 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -70,7 +70,7 @@ def __call__(self): is_draft=False, schedule=None, data_source=data_source_factory.create, - org=1) + org_id=1) query_with_params_factory = ModelFactory(redash.models.Query, name='New Query with Params', @@ -81,7 +81,7 @@ def __call__(self): is_draft=False, schedule=None, data_source=data_source_factory.create, - org=1) + org_id=1) access_permission_factory = ModelFactory(redash.models.AccessPermission, object_id=query_factory.create, @@ -103,7 +103,7 @@ def __call__(self): query="SELECT 1", query_hash=gen_query_hash('SELECT 1'), data_source=data_source_factory.create, - org=1) + org_id=1) visualization_factory = ModelFactory(redash.models.Visualization, type='CHART', @@ -120,7 +120,7 @@ def __call__(self): visualization=visualization_factory.create) destination_factory = ModelFactory(redash.models.NotificationDestination, - org=1, + org_id=1, user=user_factory.create, name='Destination', type='slack', @@ -233,21 +233,23 @@ def create_alert_subscription(self, **kwargs): return alert_subscription_factory.create(**args) def create_data_source(self, **kwargs): + group = None + if 'group' in kwargs: + group = kwargs.pop('group') args = { 'org': self.org } args.update(kwargs) - if 'group' in kwargs and 'org' not in kwargs: - args['org'] = kwargs['group'].org + if group and 'org' not in kwargs: + args['org'] = group.org data_source = data_source_factory.create(**args) - if 'group_id' in kwargs: + if group: view_only = kwargs.pop('view_only', False) - db.session.add(redash.models.DataSourceGroup( - group_id=kwargs['group_id'], + group=group, data_source=data_source, view_only=view_only)) @@ -294,7 +296,7 @@ def create_query_result(self, **kwargs): args.update(kwargs) if 'data_source' in args and 'org' not in args: - args['org'] = args['data_source'].org_id + args['org'] = args['data_source'].org return query_result_factory.create(**args) diff --git a/tests/tasks/test_refresh_queries.py b/tests/tasks/test_refresh_queries.py index 2ada084d7c..8927f9c928 100644 --- a/tests/tasks/test_refresh_queries.py +++ b/tests/tasks/test_refresh_queries.py @@ -98,7 +98,7 @@ def test_enqueues_only_for_relevant_data_source(self): query = self.factory.create_query(schedule="60") query2 = self.factory.create_query(schedule="3600", query=query.query, query_hash=query.query_hash) import psycopg2 - retrieved_at = utcnow().replace(tzinfo=psycopg2.tz.FixedOffsetTimezone(offset=0, name=None)) - datetime.timedelta(minutes=10) + retrieved_at = utcnow() - datetime.timedelta(minutes=10) query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, query_hash=query.query_hash) query.latest_query_data = query_result diff --git a/tests/test_models.py b/tests/test_models.py index e1eab88614..ab8eea5c7e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,8 +40,7 @@ def test_search_finds_in_name(self): q1 = self.factory.create_query(name=u"Testing seåřċħ") q2 = self.factory.create_query(name=u"Testing seåřċħing") q3 = self.factory.create_query(name=u"Testing seå řċħ") - db.session.flush() - queries = models.Query.search(u"seåřċħ", [self.factory.default_group]) + queries = list(models.Query.search(u"seåřċħ", [self.factory.default_group])) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -62,7 +61,7 @@ def test_search_by_id_returns_query(self): q1 = self.factory.create_query(description="Testing search") q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") - + db.session.flush() queries = models.Query.search(str(q3.id), [self.factory.default_group]) self.assertIn(q3, queries) @@ -70,25 +69,26 @@ def test_search_by_id_returns_query(self): self.assertNotIn(q2, queries) def test_search_respects_groups(self): - other_group = models.Group.create(org=self.factory.org, name="Other Group") + other_group = models.Group(org=self.factory.org, name="Other Group") + db.session.add(other_group) ds = self.factory.create_data_source(group=other_group) q1 = self.factory.create_query(description="Testing search", data_source=ds) q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") - queries = models.Query.search("Testing", [self.factory.default_group]) + queries = list(models.Query.search("Testing", [self.factory.default_group])) self.assertNotIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = models.Query.search("Testing", [other_group, self.factory.default_group]) + queries = list(models.Query.search("Testing", [other_group, self.factory.default_group])) self.assertIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = models.Query.search("Testing", [other_group]) + queries = list(models.Query.search("Testing", [other_group])) self.assertIn(q1, queries) self.assertNotIn(q2, queries) self.assertNotIn(q3, queries) @@ -100,22 +100,23 @@ def test_returns_each_query_only_once(self): ds.add_group(second_group, False) q1 = self.factory.create_query(description="Testing search", data_source=ds) - + db.session.flush() queries = list(models.Query.search("Testing", [self.factory.default_group, other_group, second_group])) self.assertEqual(1, len(queries)) def test_save_creates_default_visualization(self): q = self.factory.create_query() - self.assertEquals(q.visualizations.count(), 1) + db.session.flush() + self.assertEquals(len(q.visualizations), 1) def test_save_updates_updated_at_field(self): # This should be a test of ModelTimestampsMixin, but it's easier to test in context of existing model... :-\ - one_day_ago = datetime.datetime.today() - datetime.timedelta(days=1) + one_day_ago = utcnow().date() - datetime.timedelta(days=1) q = self.factory.create_query(created_at=one_day_ago, updated_at=one_day_ago) - - q.save() - + db.session.flush() + q.name = 'x' + db.session.flush() self.assertNotEqual(q.updated_at, one_day_ago) @@ -123,12 +124,11 @@ class QueryRecentTest(BaseTestCase): def test_global_recent(self): q1 = self.factory.create_query() q2 = self.factory.create_query() - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - + db.session.flush() + e = models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + db.session.add(e) recent = models.Query.recent([self.factory.default_group]) - self.assertIn(q1, recent) self.assertNotIn(q2, recent) @@ -151,10 +151,10 @@ def test_recent_excludes_drafts(self): def test_recent_for_user(self): q1 = self.factory.create_query() q2 = self.factory.create_query() - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - + db.session.flush() + e = models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + db.session.add(e) recent = models.Query.recent([self.factory.default_group], user_id=self.factory.user.id) self.assertIn(q1, recent) @@ -168,11 +168,11 @@ def test_respects_groups(self): q1 = self.factory.create_query() ds = self.factory.create_data_source(group=self.factory.create_group()) q2 = self.factory.create_query(data_source=ds) - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q2.id) + db.session.flush() + models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q2.id) recent = models.Query.recent([self.factory.default_group]) @@ -182,17 +182,17 @@ def test_respects_groups(self): class ShouldScheduleNextTest(TestCase): def test_interval_schedule_that_needs_reschedule(self): - now = datetime.datetime.now() + now = utcnow() two_hours_ago = now - datetime.timedelta(hours=2) self.assertTrue(models.should_schedule_next(two_hours_ago, now, "3600")) def test_interval_schedule_that_doesnt_need_reschedule(self): - now = datetime.datetime.now() + now = utcnow() half_an_hour_ago = now - datetime.timedelta(minutes=30) self.assertFalse(models.should_schedule_next(half_an_hour_ago, now, "3600")) def test_exact_time_that_needs_reschedule(self): - now = datetime.datetime.now() + now = utcnow() yesterday = now - datetime.timedelta(days=1) scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) @@ -205,7 +205,7 @@ def test_exact_time_that_doesnt_need_reschedule(self): self.assertFalse(models.should_schedule_next(yesterday, now, schedule)) def test_exact_time_with_day_change(self): - now = datetime.datetime.now().replace(hour=0, minute=1) + now = utcnow().replace(hour=0, minute=1) previous = (now - datetime.timedelta(days=2)).replace(hour=23, minute=59) schedule = "23:59".format(now.hour + 3) self.assertTrue(models.should_schedule_next(previous, now, schedule)) @@ -220,21 +220,19 @@ def test_outdated_queries_skips_unscheduled_queries(self): self.assertNotIn(query, queries) def test_outdated_queries_works_with_ttl_based_schedule(self): - two_hours_ago = datetime.datetime.now() - datetime.timedelta(hours=2) + two_hours_ago = utcnow() - datetime.timedelta(hours=2) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query, retrieved_at=two_hours_ago) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=two_hours_ago) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertIn(query, queries) def test_skips_fresh_queries(self): - half_an_hour_ago = datetime.datetime.now() - datetime.timedelta(minutes=30) + half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query, retrieved_at=half_an_hour_ago) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertNotIn(query, queries) @@ -242,9 +240,8 @@ def test_skips_fresh_queries(self): def test_outdated_queries_works_with_specific_time_schedule(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule=half_an_hour_ago.strftime('%H:%M')) - query_result = self.factory.create_query_result(query=query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertIn(query, queries) @@ -256,53 +253,51 @@ def setUp(self): def test_archive_query_sets_flag(self): query = self.factory.create_query() + db.session.flush() query.archive() - query = models.Query.get_by_id(query.id) self.assertEquals(query.is_archived, True) def test_archived_query_doesnt_return_in_all(self): query = self.factory.create_query(schedule="1") - yesterday = datetime.datetime.now() - datetime.timedelta(days=1) - query_result, _ = models.QueryResult.store_result(query.org, query.data_source.id, query.query_hash, query.query, "1", - 123, yesterday) + yesterday = utcnow() - datetime.timedelta(days=1) + query_result, _ = models.QueryResult.store_result( + query.org, query.data_source, query.query_hash, query.query, + "1", 123, yesterday) query.latest_query_data = query_result - query.save() - - self.assertIn(query, list(models.Query.all_queries(query.groups.keys()))) + groups = list(models.Group.query.filter(models.Group.id.in_(query.groups))) + self.assertIn(query, list(models.Query.all_queries(groups))) self.assertIn(query, models.Query.outdated_queries()) - + db.session.flush() query.archive() - self.assertNotIn(query, list(models.Query.all_queries(query.groups.keys()))) + self.assertNotIn(query, list(models.Query.all_queries(groups))) self.assertNotIn(query, models.Query.outdated_queries()) def test_removes_associated_widgets_from_dashboards(self): widget = self.factory.create_widget() query = widget.visualization.query - + db.session.commit() query.archive() - - self.assertRaises(models.Widget.DoesNotExist, models.Widget.get_by_id, widget.id) + db.session.flush() + self.assertEqual(db.session.query(models.Widget).get(widget.id), None) def test_removes_scheduling(self): query = self.factory.create_query(schedule="1") query.archive() - query = models.Query.get_by_id(query.id) - self.assertEqual(None, query.schedule) def test_deletes_alerts(self): subscription = self.factory.create_alert_subscription() query = subscription.alert.query - + db.session.commit() query.archive() - - self.assertRaises(models.Alert.DoesNotExist, models.Alert.get_by_id, subscription.alert.id) - self.assertRaises(models.AlertSubscription.DoesNotExist, models.AlertSubscription.get_by_id, subscription.id) + db.session.flush() + self.assertEqual(db.session.query(models.Alert).get(subscription.alert.id), None) + self.assertEqual(db.session.query(models.AlertSubscription).get(subscription.id), None) class DataSourceTest(BaseTestCase): @@ -355,12 +350,6 @@ def test_get_latest_returns_when_found(self): self.assertEqual(qr, found_query_result) - def test_get_latest_works_with_data_source_id(self): - qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source.id, qr.query, 60) - - self.assertEqual(qr, found_query_result) - def test_get_latest_doesnt_return_query_from_different_data_source(self): qr = self.factory.create_query_result() data_source = self.factory.create_data_source() @@ -369,7 +358,7 @@ def test_get_latest_doesnt_return_query_from_different_data_source(self): self.assertIsNone(found_query_result) def test_get_latest_doesnt_return_if_ttl_expired(self): - yesterday = datetime.datetime.now() - datetime.timedelta(days=1) + yesterday = utcnow() - datetime.timedelta(days=1) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=60) @@ -377,7 +366,7 @@ def test_get_latest_doesnt_return_if_ttl_expired(self): self.assertIsNone(found_query_result) def test_get_latest_returns_if_ttl_not_expired(self): - yesterday = datetime.datetime.now() - datetime.timedelta(seconds=30) + yesterday = utcnow() - datetime.timedelta(seconds=30) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=120) @@ -385,7 +374,7 @@ def test_get_latest_returns_if_ttl_not_expired(self): self.assertEqual(found_query_result, qr) def test_get_latest_returns_the_most_recent_result(self): - yesterday = datetime.datetime.now() - datetime.timedelta(seconds=30) + yesterday = utcnow() - datetime.timedelta(seconds=30) old_qr = self.factory.create_query_result(retrieved_at=yesterday) qr = self.factory.create_query_result() @@ -394,10 +383,10 @@ def test_get_latest_returns_the_most_recent_result(self): self.assertEqual(found_query_result.id, qr.id) def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): - yesterday = datetime.datetime.now() + datetime.timedelta(days=-100) + yesterday = utcnow() + datetime.timedelta(days=-100) very_old = self.factory.create_query_result(retrieved_at=yesterday) - yesterday = datetime.datetime.now() + datetime.timedelta(days=-1) + yesterday = utcnow() + datetime.timedelta(days=-1) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, -1) @@ -406,7 +395,7 @@ def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): class TestUnusedQueryResults(BaseTestCase): def test_returns_only_unused_query_results(self): - two_weeks_ago = datetime.datetime.now() - datetime.timedelta(days=14) + two_weeks_ago = utcnow() - datetime.timedelta(days=14) qr = self.factory.create_query_result() query = self.factory.create_query(latest_query_data=qr) unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) @@ -415,7 +404,7 @@ def test_returns_only_unused_query_results(self): self.assertNotIn(qr, models.QueryResult.unused()) def test_returns_only_over_a_week_old_results(self): - two_weeks_ago = datetime.datetime.now() - datetime.timedelta(days=14) + two_weeks_ago = utcnow() - datetime.timedelta(days=14) unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) new_unused_qr = self.factory.create_query_result() @@ -428,15 +417,21 @@ def test_returns_only_queries_in_given_groups(self): ds1 = self.factory.create_data_source() ds2 = self.factory.create_data_source() - group1 = models.Group.create(name="g1", org=ds1.org) - group2 = models.Group.create(name="g2", org=ds1.org) - - models.DataSourceGroup.create(group=group1, data_source=ds1, permissions=['create', 'view']) - models.DataSourceGroup.create(group=group2, data_source=ds2, permissions=['create', 'view']) + group1 = models.Group(name="g1", org=ds1.org, permissions=['create', 'view']) + group2 = models.Group(name="g2", org=ds1.org, permissions=['create', 'view']) q1 = self.factory.create_query(data_source=ds1) q2 = self.factory.create_query(data_source=ds2) + db.session.add_all([ + ds1, ds2, + group1, group2, + q1, q2, + models.DataSourceGroup( + group=group1, data_source=ds1), + models.DataSourceGroup(group=group2, data_source=ds2) + ]) + db.session.flush() self.assertIn(q1, list(models.Query.all_queries([group1]))) self.assertNotIn(q2, list(models.Query.all_queries([group1]))) self.assertIn(q1, list(models.Query.all_queries([group1, group2]))) @@ -448,14 +443,14 @@ def test_default_group_always_added(self): user = self.factory.create_user() user.update_group_assignments(["g_unknown"]) - self.assertItemsEqual([user.org.default_group.id], user.groups) + self.assertItemsEqual([user.org.default_group.id], user.group_ids) def test_update_group_assignments(self): user = self.factory.user - new_group = models.Group.create(id='999', name="g1", org=user.org) + new_group = models.Group(id=999, name="g1", org=user.org) user.update_group_assignments(["g1"]) - self.assertItemsEqual([user.org.default_group.id, new_group.id], user.groups) + self.assertItemsEqual([user.org.default_group.id, new_group.id], user.group_ids) class TestGroup(BaseTestCase): @@ -463,9 +458,9 @@ def test_returns_groups_with_specified_names(self): org1 = self.factory.create_org() org2 = self.factory.create_org() - matching_group1 = models.Group.create(id='999', name="g1", org=org1) - matching_group2 = models.Group.create(id='888', name="g2", org=org1) - non_matching_group = models.Group.create(id='777', name="g1", org=org2) + matching_group1 = models.Group(id=999, name="g1", org=org1) + matching_group2 = models.Group(id=888, name="g2", org=org1) + non_matching_group = models.Group(id=777, name="g1", org=org2) groups = models.Group.find_by_name(org1, ["g1", "g2"]) self.assertIn(matching_group1, groups) @@ -475,7 +470,7 @@ def test_returns_groups_with_specified_names(self): def test_returns_no_groups(self): org1 = self.factory.create_org() - models.Group.create(id='999', name="g1", org=org1) + models.Group(id=999, name="g1", org=org1) self.assertEqual([], models.Group.find_by_name(org1, ["non-existing"])) @@ -490,9 +485,9 @@ def setUp(self): self.data = "data" def test_stores_the_result(self): - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, - self.data, self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) self.assertEqual(query_result.data, self.data) self.assertEqual(query_result.runtime, self.runtime) @@ -506,39 +501,39 @@ def test_updates_existing_queries(self): query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query) - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_hash(self): query1 = self.factory.create_query(query=self.query) query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query + "123") - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertNotEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertNotEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_data_source(self): query1 = self.factory.create_query(query=self.query) query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query, data_source=self.factory.create_data_source()) - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertNotEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertNotEqual(query3.latest_query_data, query_result) class TestEvents(BaseTestCase): @@ -546,6 +541,7 @@ def raw_event(self): timestamp = 1411778709.791 user = self.factory.user created_at = datetime.datetime.utcfromtimestamp(timestamp) + db.session.flush() raw_event = {"action": "view", "timestamp": timestamp, "object_type": "dashboard", @@ -559,7 +555,7 @@ def test_records_event(self): raw_event, user, created_at = self.raw_event() event = models.Event.record(raw_event) - + db.session.flush() self.assertEqual(event.user, user) self.assertEqual(event.action, "view") self.assertEqual(event.object_type, "dashboard") @@ -580,32 +576,36 @@ class TestWidgetDeleteInstance(BaseTestCase): def test_delete_removes_from_layout(self): widget = self.factory.create_widget() widget2 = self.factory.create_widget(dashboard=widget.dashboard) + db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) - widget.dashboard.save() - widget.delete_instance() - + db.session.delete(widget) + db.session.flush() self.assertEquals(json.dumps([[widget2.id]]), widget.dashboard.layout) def test_delete_removes_empty_rows(self): widget = self.factory.create_widget() widget2 = self.factory.create_widget(dashboard=widget.dashboard) + db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) - widget.dashboard.save() - widget.delete_instance() - widget2.delete_instance() - + db.session.flush() + db.session.delete(widget) + db.session.delete(widget2) + db.session.flush() self.assertEquals("[]", widget.dashboard.layout) def _set_up_dashboard_test(d): - d.g1 = d.factory.create_group(name='First') - d.g2 = d.factory.create_group(name='Second') + d.g1 = d.factory.create_group(name='First', permissions=['create', 'view']) + d.g2 = d.factory.create_group(name='Second', permissions=['create', 'view']) d.ds1 = d.factory.create_data_source() d.ds2 = d.factory.create_data_source() + db.session.flush() d.u1 = d.factory.create_user(group_ids=[d.g1.id]) d.u2 = d.factory.create_user(group_ids=[d.g2.id]) - models.DataSourceGroup.create(group=d.g1, data_source=d.ds1, permissions=['create', 'view']) - models.DataSourceGroup.create(group=d.g2, data_source=d.ds2, permissions=['create', 'view']) + db.session.add_all([ + models.DataSourceGroup(group=d.g1, data_source=d.ds1), + models.DataSourceGroup(group=d.g2, data_source=d.ds2) + ]) d.q1 = d.factory.create_query(data_source=d.ds1) d.q2 = d.factory.create_query(data_source=d.ds2) d.v1 = d.factory.create_visualization(query=d.q1) @@ -624,43 +624,49 @@ def setUp(self): def test_requires_group_or_user_id(self): d1 = self.factory.create_dashboard() - - self.assertNotIn(d1, models.Dashboard.all(d1.user.org, d1.user.groups, None)) - self.assertIn(d1, models.Dashboard.all(d1.user.org, [0], d1.user.id)) + self.assertNotIn(d1, list(models.Dashboard.all( + d1.user.org, d1.user.group_ids, None))) + l2 = list(models.Dashboard.all( + d1.user.org, [0], d1.user.id)) + self.assertIn(d1, l2) def test_returns_dashboards_based_on_groups(self): - self.assertIn(self.w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertIn(self.w2.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) + self.assertIn(self.w1.dashboard, list(models.Dashboard.all( + self.u1.org, self.u1.group_ids, None))) + self.assertIn(self.w2.dashboard, list(models.Dashboard.all( + self.u2.org, self.u2.group_ids, None))) + self.assertNotIn(self.w1.dashboard, list(models.Dashboard.all( + self.u2.org, self.u2.group_ids, None))) + self.assertNotIn(self.w2.dashboard, list(models.Dashboard.all( + self.u1.org, self.u1.group_ids, None))) def test_returns_each_dashboard_once(self): - dashboards = list(models.Dashboard.all(self.u2.org, self.u2.groups, None)) + dashboards = list(models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) self.assertEqual(len(dashboards), 2) def test_returns_dashboard_you_have_partial_access_to(self): - self.assertIn(self.w5.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) + self.assertIn(self.w5.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) def test_returns_dashboards_created_by_user(self): d1 = self.factory.create_dashboard(user=self.u1) - - self.assertIn(d1, models.Dashboard.all(self.u1.org, self.u1.groups, self.u1.id)) - self.assertIn(d1, models.Dashboard.all(self.u1.org, [0], self.u1.id)) - self.assertNotIn(d1, models.Dashboard.all(self.u2.org, self.u2.groups, self.u2.id)) + db.session.flush() + self.assertIn(d1, list(models.Dashboard.all(self.u1.org, self.u1.group_ids, self.u1.id))) + self.assertIn(d1, list(models.Dashboard.all(self.u1.org, [0], self.u1.id))) + self.assertNotIn(d1, list(models.Dashboard.all(self.u2.org, self.u2.group_ids, self.u2.id))) def test_returns_dashboards_with_text_widgets(self): w1 = self.factory.create_widget(visualization=None) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) user = self.factory.create_user(org=self.factory.create_org()) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.group_ids, None)) class TestDashboardRecent(BaseTestCase): @@ -669,12 +675,12 @@ def setUp(self): _set_up_dashboard_test(self) def test_returns_recent_dashboards_basic(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - - self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u2.groups, None)) + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id)) + self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u2.group_ids, None)) def test_recent_excludes_drafts(self): models.Event.create(org=self.factory.org, user=self.u1, action="view", @@ -688,53 +694,62 @@ def test_recent_excludes_drafts(self): def test_returns_recent_dashboards_created_by_user(self): d1 = self.factory.create_dashboard(user=self.u1) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=d1.id) - + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=d1.id)) self.assertIn(d1, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(d1, models.Dashboard.recent(self.u2.org, [0], self.u2.id)) def test_returns_recent_dashboards_with_no_visualizations(self): w1 = self.factory.create_widget(visualization=None) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=w1.dashboard.id) - + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=w1.dashboard.id)) + db.session.flush() self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) def test_restricts_dashboards_for_user(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u2, action="view", - object_type="dashboard", object_id=self.w2.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w5.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u2, action="view", - object_type="dashboard", object_id=self.w5.dashboard.id) - - self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w2.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) + db.session.flush() + db.session.add_all([ + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id), + models.Event(org=self.factory.org, user=self.u2, action="view", + object_type="dashboard", object_id=self.w2.dashboard.id), + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w5.dashboard.id), + models.Event(org=self.factory.org, user=self.u2, action="view", + object_type="dashboard", object_id=self.w5.dashboard.id) + ]) + db.session.flush() + self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w2.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) + self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) + self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) def test_returns_each_dashboard_once(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - - dashboards = list(models.Dashboard.recent(self.u1.org, self.u1.groups, None)) + db.session.flush() + db.session.add_all([ + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id), + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id) + ]) + db.session.flush() + dashboards = list(models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) self.assertEqual(len(dashboards), 1) def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=w1.dashboard.id) - + db.session.flush() + db.session.add(models.Event( + org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=w1.dashboard.id)) + db.session.flush() user = self.factory.create_user(org=self.factory.create_org()) - self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(w1.dashboard, models.Dashboard.recent(user.org, user.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(w1.dashboard, models.Dashboard.recent(user.org, user.group_ids, None)) From f00d77dec477db48f3fac493def8740eadd25171 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Mon, 28 Nov 2016 08:41:29 -0600 Subject: [PATCH 04/80] auth tests wip --- redash/authentication/__init__.py | 20 +++++++----- redash/authentication/google_oauth.py | 7 ++-- redash/handlers/base.py | 2 +- redash/handlers/queries.py | 3 +- redash/metrics/request.py | 4 +-- redash/models.py | 47 ++++++++++++++------------- tests/factories.py | 8 ++--- tests/handlers/__init__.py | 3 ++ tests/test_authentication.py | 19 +++++------ 9 files changed, 62 insertions(+), 51 deletions(-) diff --git a/redash/authentication/__init__.py b/redash/authentication/__init__.py index 788ca4e3ca..553669fd4b 100644 --- a/redash/authentication/__init__.py +++ b/redash/authentication/__init__.py @@ -5,6 +5,7 @@ import logging from flask import redirect, request, jsonify, url_for +from sqlalchemy.orm.exc import NoResultFound from redash import models, settings from redash.authentication.org_resolving import current_org @@ -36,9 +37,10 @@ def sign(key, path, expires): @login_manager.user_loader def load_user(user_id): + org = current_org._get_current_object() try: - return models.User.get_by_id_and_org(user_id, current_org.id) - except models.User.DoesNotExist: + return models.User.get_by_id_and_org(user_id, org) + except NoResultFound: return None @@ -51,14 +53,14 @@ def hmac_load_user_from_request(request): # TODO: 3600 should be a setting if signature and time.time() < expires <= time.time() + 3600: if user_id: - user = models.User.get_by_id(user_id) + user = models.User.query.get(user_id) calculated_signature = sign(user.api_key, request.path, expires) if user.api_key and signature == calculated_signature: return user if query_id: - query = models.Query.get(models.Query.id == query_id) + query = models.db.session.query(models.Query).filter(models.Query.id == query_id).one() calculated_signature = sign(query.api_key, request.path, expires) if query.api_key and signature == calculated_signature: @@ -74,15 +76,16 @@ def get_user_from_api_key(api_key, query_id): user = None # TODO: once we switch all api key storage into the ApiKey model, this code will be much simplified + org = current_org._get_current_object() try: - user = models.User.get_by_api_key_and_org(api_key, current_org.id) - except models.User.DoesNotExist: + user = models.User.get_by_api_key_and_org(api_key, org) + except NoResultFound: try: api_key = models.ApiKey.get_by_api_key(api_key) user = models.ApiUser(api_key, api_key.org, []) - except models.ApiKey.DoesNotExist: + except NoResultFound: if query_id: - query = models.Query.get_by_id_and_org(query_id, current_org.id) + query = models.Query.get_by_id_and_org(query_id, org) if query and query.api_key == api_key: user = models.ApiUser(api_key, query.org, query.groups.keys(), name="ApiKey: Query {}".format(query.id)) @@ -105,7 +108,6 @@ def get_api_key_from_request(request): def api_key_load_user_from_request(request): api_key = get_api_key_from_request(request) query_id = request.view_args.get('query_id', None) - user = get_user_from_api_key(api_key, query_id) return user diff --git a/redash/authentication/google_oauth.py b/redash/authentication/google_oauth.py index bfda933313..c02f467f47 100644 --- a/redash/authentication/google_oauth.py +++ b/redash/authentication/google_oauth.py @@ -3,6 +3,8 @@ from flask import redirect, url_for, Blueprint, flash, request, session from flask_login import login_user from flask_oauthlib.client import OAuth +from sqlalchemy.orm.exc import NoResultFound + from redash import models, settings from redash.authentication.org_resolving import current_org @@ -63,9 +65,10 @@ def create_and_login_user(org, name, email): logger.debug("Updating user name (%r -> %r)", user_object.name, name) user_object.name = name user_object.save() - except models.User.DoesNotExist: + except NoResultFound: logger.debug("Creating user object (%r)", name) - user_object = models.User.create(org=org, name=name, email=email, group_ids=[org.default_group.id]) + user_object = models.User(org=org, name=name, email=email, group_ids=[org.default_group.id]) + models.db.session.add(user_object) login_user(user_object, remember=True) diff --git a/redash/handlers/base.py b/redash/handlers/base.py index d213a12d68..465b7480fe 100644 --- a/redash/handlers/base.py +++ b/redash/handlers/base.py @@ -93,7 +93,7 @@ def paginate(query_set, page, page_size, serializer): 'count': count, 'page': page, 'page_size': page_size, - 'results': [serializer(result) for result in results], + 'results': [serializer(result) for result in results.items], } diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index 764ea8b94b..1a781f94eb 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -76,7 +76,8 @@ def post(self): @require_permission('view_query') def get(self): - results = models.Query.all_queries(self.current_user.groups) + results = models.Query.all_queries([models.Group.query.get(g_id) + for g_id in self.current_user.group_ids]) page = request.args.get('page', 1, type=int) page_size = request.args.get('page_size', 25, type=int) return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False)) diff --git a/redash/metrics/request.py b/redash/metrics/request.py index fce2fdd9b3..82ee0a6459 100644 --- a/redash/metrics/request.py +++ b/redash/metrics/request.py @@ -27,8 +27,8 @@ def calculate_metrics(response): response.content_type, response.content_length, request_duration, - db.database.query_count, - db.database.query_duration) + # XXX instrument SQLA for metrics + None, None) statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration) diff --git a/redash/models.py b/redash/models.py index 7d58cb13d5..2270195411 100644 --- a/redash/models.py +++ b/redash/models.py @@ -128,7 +128,7 @@ class ConflictDetectedError(Exception): class BelongsToOrgMixin(object): @classmethod def get_by_id_and_org(cls, object_id, org): - return cls.query.filter(cls.id == object_id, cls.org == org).one_or_none() + return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one_or_none() class PermissionsCheckMixin(object): @@ -265,7 +265,7 @@ def create_group_hack(*a, **kw): class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey('organizations.id')) - org = db.relationship(Organization, backref="users") + org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic")) name = Column(db.String(320)) email = Column(db.String(320)) password_hash = Column(db.String(128), nullable=True) @@ -287,7 +287,7 @@ def to_dict(self, with_api_key=False): 'name': self.name, 'email': self.email, 'gravatar_url': self.gravatar_url, - 'groups': self.groups, + 'groups': self.group_ids, 'updated_at': self.updated_at, 'created_at': self.created_at } @@ -311,15 +311,15 @@ def gravatar_url(self): def permissions(self): # TODO: this should be cached. return list(itertools.chain(*[g.permissions for g in - Group.select().where(Group.id << self.groups)])) + Group.query.filter(Group.id.in_(self.group_ids))])) @classmethod def get_by_email_and_org(cls, email, org): - return cls.get(cls.email == email, cls.org == org) + return cls.query.filter(cls.email == email, cls.org == org).one() @classmethod def get_by_api_key_and_org(cls, api_key, org): - return cls.get(cls.api_key == api_key, cls.org == org) + return cls.query.filter(cls.api_key == api_key, cls.org == org).one() @classmethod def all(cls, org): @@ -566,7 +566,7 @@ def store_result(cls, org, data_source, query_hash, query, data, run_time, retri for q in queries: q.latest_query_data = query_result db.session.add(q) - query_ids = [q.id for q in queries] + query_ids = [q.id for q in queries] logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash) return query_result, query_ids @@ -618,7 +618,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): latest_query_data = db.relationship(QueryResult) name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) - query = Column(db.Text) + query_text = Column("query", db.Text) query_hash = Column(db.String(32)) api_key = Column(db.String(40), default=generate_query_api_key) user_id = Column(db.Integer, db.ForeignKey("users.id")) @@ -639,10 +639,10 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True): d = { 'id': self.id, - 'latest_query_data_id': self._data.get('latest_query_data', None), + 'latest_query_data_id': self.latest_query_data, 'name': self.name, 'description': self.description, - 'query': self.query, + 'query': self.query_text, 'query_hash': self.query_hash, 'schedule': self.schedule, 'api_key': self.api_key, @@ -666,8 +666,12 @@ def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, w d['last_modified_by_id'] = self.last_modified_by_id if with_stats: - d['retrieved_at'] = self.retrieved_at - d['runtime'] = self.runtime + if self.latest_query_data is not None: + d['retrieved_at'] = self.retrieved_at + d['runtime'] = self.runtime + else: + d['retrieved_at'] = None + d['runtime'] = None if with_visualizations: d['visualizations'] = [vis.to_dict(with_query=False) @@ -692,9 +696,8 @@ def archive(self, user=None): @classmethod def all_queries(cls, groups, drafts=False): - q = (db.session.query(Query) + q = (cls.query.join(User, Query.user_id == User.id) .outerjoin(QueryResult) - .join(User, Query.user_id == User.id) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .filter(Query.is_archived == False) .filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\ @@ -714,7 +717,7 @@ def by_user(cls, user, drafts): @classmethod def outdated_queries(cls): - queries = (db.session.query(Query) + queries = (cls.query(Query) .join(QueryResult) .join(DataSource) .filter(Query.schedule != None)) @@ -740,7 +743,7 @@ def search(cls, term, groups): where &= Query.is_archived == False where &= DataSourceGroup.group_id.in_([g.id for g in groups]) query_ids = ( - db.session.query(Query.id).join( + cls.query(Query.id).join( DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .filter(where)).distinct() @@ -750,7 +753,7 @@ def search(cls, term, groups): @classmethod def recent(cls, groups, user_id=None, limit=20): - query = (db.session.query(Query).join(User, Query.user_id == User.id) + query = (cls.query(Query).join(User, Query.user_id == User.id) .filter(Event.created_at > (db.func.current_date() - 7)) .join(Event, Query.id == Event.object_id.cast(db.Integer)) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) @@ -852,7 +855,7 @@ def groups(self): def __unicode__(self): return unicode(self.id) -@listens_for(Query.query, 'set') +@listens_for(Query.query_text, 'set') def gen_query_hash(target, val, oldval, initiator): target.query_hash = utils.gen_query_hash(val) @@ -1051,7 +1054,7 @@ def groups(self): def generate_slug(ctx): slug = utils.slugify(ctx.current_parameters['name']) tries = 1 - while db.session.query(Dashboard).filter(Dashboard.slug == slug).first() is not None: + while Dashboard.query.filter(Dashboard.slug == slug).first() is not None: slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries) tries += 1 return slug @@ -1134,7 +1137,7 @@ def to_dict(self, with_widgets=False, user=None): @classmethod def all(cls, org, group_ids, user_id): query = ( - db.session.query(Dashboard) + Dashboard.query .outerjoin(Widget) .outerjoin(Visualization) .outerjoin(Query) @@ -1151,7 +1154,7 @@ def all(cls, org, group_ids, user_id): @classmethod def recent(cls, org, group_ids, user_id, for_user=False, limit=20): - query = (db.session.query(Dashboard) + query = (Dashboard.query .outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer)) .outerjoin(Widget) .outerjoin(Visualization) @@ -1331,7 +1334,7 @@ class ApiKey(TimestampMixin, GFKBase, db.Model): @classmethod def get_by_api_key(cls, api_key): - return cls.get(cls.api_key==api_key, cls.active==True) + return cls.query.filter(cls.api_key==api_key, cls.active==True).one() @classmethod def get_by_object(cls, object): diff --git a/tests/factories.py b/tests/factories.py index df242de1fc..f23f427965 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -64,7 +64,7 @@ def __call__(self): query_factory = ModelFactory(redash.models.Query, name='Query', description='', - query='SELECT 1', + query_text='SELECT 1', user=user_factory.create, is_archived=False, is_draft=False, @@ -75,7 +75,7 @@ def __call__(self): query_with_params_factory = ModelFactory(redash.models.Query, name='New Query with Params', description='', - query='SELECT {{param1}}', + query_text='SELECT {{param1}}', user=user_factory.create, is_archived=False, is_draft=False, @@ -100,14 +100,14 @@ def __call__(self): data='{"columns":{}, "rows":[]}', runtime=1, retrieved_at=utcnow, - query="SELECT 1", + query_text="SELECT 1", query_hash=gen_query_hash('SELECT 1'), data_source=data_source_factory.create, org_id=1) visualization_factory = ModelFactory(redash.models.Visualization, type='CHART', - query=query_factory.create, + query_text=query_factory.create, name='Chart', description='', options='{}') diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py index 748194b44e..3e50654083 100644 --- a/tests/handlers/__init__.py +++ b/tests/handlers/__init__.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from tests.factories import user_factory +from redash.models import db from redash.utils import json_dumps from redash.wsgi import app @@ -10,6 +11,8 @@ def authenticate_request(c, user): with c.session_transaction() as sess: + if user.id is None: + db.session.flush() sess['user_id'] = user.id diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 231b830820..1b10652353 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -16,8 +16,9 @@ class TestApiKeyAuthentication(BaseTestCase): # def setUp(self): super(TestApiKeyAuthentication, self).setUp() - self.api_key = 10 + self.api_key = '10' self.query = self.factory.create_query(api_key=self.api_key) + models.db.session.flush() self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id) self.queries_url = '/{}/api/queries'.format(self.factory.org.slug) @@ -43,6 +44,7 @@ def test_no_query_id(self): def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") + models.db.session.flush() with app.test_client() as c: rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) self.assertEqual(user.id, api_key_load_user_from_request(request).id) @@ -71,8 +73,9 @@ class TestHMACAuthentication(BaseTestCase): # def setUp(self): super(TestHMACAuthentication, self).setUp() - self.api_key = 10 + self.api_key = '10' self.query = self.factory.create_query(api_key=self.api_key) + models.db.session.flush() self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id) self.expires = time.time() + 1800 @@ -102,10 +105,11 @@ def test_no_query_id(self): def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") path = '/api/queries/' + models.db.session.flush() with app.test_client() as c: signature = sign(user.api_key, path, self.expires) rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) - self.assertEqual(user.id, hmac_load_user_from_request(request).id) + self.assertEqual(user, hmac_load_user_from_request(request)) class TestCreateAndLoginUser(BaseTestCase): @@ -124,8 +128,8 @@ def test_creates_vaild_new_user(self): create_and_login_user(self.factory.org, name, email) self.assertTrue(login_user_mock.called) - user = models.User.get(models.User.email == email) - + user = models.User.query.filter(models.User.email == email).one() + self.assertEqual(user.email, email) class TestVerifyProfile(BaseTestCase): def test_no_domain_allowed_for_org(self): @@ -135,29 +139,24 @@ def test_no_domain_allowed_for_org(self): def test_domain_not_in_org_domains_list(self): profile = dict(email='arik@example.com') self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] - self.factory.org.save() self.assertFalse(verify_profile(self.factory.org, profile)) def test_domain_in_org_domains_list(self): profile = dict(email='arik@example.com') self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com'] - self.factory.org.save() self.assertTrue(verify_profile(self.factory.org, profile)) self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com'] - self.factory.org.save() self.assertTrue(verify_profile(self.factory.org, profile)) def test_org_in_public_mode_accepts_any_domain(self): profile = dict(email='arik@example.com') self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [] - self.factory.org.save() self.assertTrue(verify_profile(self.factory.org, profile)) def test_user_not_in_domain_but_account_exists(self): profile = dict(email='arik@example.com') self.factory.create_user(email='arik@example.com') self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] - self.factory.org.save() self.assertTrue(verify_profile(self.factory.org, profile)) From 982667ffa93870f30b2a0e98a7cf6395138baea1 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Fri, 14 Oct 2016 17:25:30 -0500 Subject: [PATCH 05/80] Make draft status for queries and dashboards toggleable. --- redash/handlers/dashboards.py | 10 +++++----- redash/models.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redash/handlers/dashboards.py b/redash/handlers/dashboards.py index 34c1500917..0fcb8eea17 100644 --- a/redash/handlers/dashboards.py +++ b/redash/handlers/dashboards.py @@ -32,11 +32,11 @@ def get(self): @require_permission('create_dashboard') def post(self): dashboard_properties = request.get_json(force=True) - dashboard = models.Dashboard.create(name=dashboard_properties['name'], - org=self.current_org, - user=self.current_user, - is_draft=True, - layout='[]') + dashboard = models.Dashboard(name=dashboard_properties['name'], + org=self.current_org, + user=self.current_user, + is_draft=True, + layout='[]') return dashboard.to_dict() diff --git a/redash/models.py b/redash/models.py index 2270195411..a5f83d6eda 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1168,6 +1168,7 @@ def recent(cls, org, group_ids, user_id, for_user=False, limit=20): Dashboard.org == org, Dashboard.is_draft == False, Dashboard.is_archived == False, + Dashboard.is_draft == False, DataSourceGroup.group_id.in_(group_ids) | (Dashboard.user_id == user_id) | ((Widget.dashboard != None) & (Widget.visualization == None))) From f55b836896a9d29a4e9f51f957ed98980a581ecf Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 11:03:27 +0200 Subject: [PATCH 06/80] Fix: fix database URL --- redash/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redash/settings.py b/redash/settings.py index f03add2ac5..2129db7443 100644 --- a/redash/settings.py +++ b/redash/settings.py @@ -65,7 +65,7 @@ def all_settings(): STATSD_USE_TAGS = parse_boolean(os.environ.get('REDASH_STATSD_USE_TAGS', "false")) # Connection settings for re:dash's own database (where we store the queries, results, etc) -SQLALCHEMY_DATABASE_URI = os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql://postgres")) +SQLALCHEMY_DATABASE_URI = os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql:///postgres")) SQLALCHEMY_TRACK_MODIFICATIONS = False # Celery related settings From b390cd2e3d459ccf69dee45437cf95e07d1a2a80 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 11:21:45 +0200 Subject: [PATCH 07/80] Close DB connection between tests. Otherwise we were running out of connections. --- tests/__init__.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 7bd5dac677..dd777490d0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,8 @@ import os +import logging +import datetime +from unittest import TestCase + os.environ['REDASH_REDIS_URL'] = "redis://localhost:6379/5" # Use different url for Celery to avoid DB being cleaned up: os.environ['REDASH_CELERY_BROKER'] = "redis://localhost:6379/6" @@ -8,19 +12,10 @@ os.environ['REDASH_GOOGLE_CLIENT_SECRET'] = "dummy" os.environ['REDASH_MULTI_ORG'] = "true" -import logging -from unittest import TestCase -import datetime -from redash import create_app, settings -from factories import Factory - -settings.DATABASE_CONFIG = { - 'name': 'circle_test', - 'threadlocals': True -} - -from redash import redis_connection import redash.models +from redash import create_app +from redash import redis_connection +from factories import Factory from tests.handlers import make_request logging.disable("INFO") @@ -29,13 +24,13 @@ class BaseTestCase(TestCase): def setUp(self): + redash.models.db.session.close() self.app = create_app() self.app_ctx = self.app.app_context() self.app_ctx.push() redash.models.create_db(True, True) self.factory = Factory() - def tearDown(self): redash.models.create_db(False, True) self.app_ctx.pop() From 2bff12b376b57534f92ff14dc0731744cfca21ed Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 12:16:32 +0200 Subject: [PATCH 08/80] Update all tests to use the same test_client --- redash/metrics/request.py | 13 ++-- redash/models.py | 2 - tests/__init__.py | 59 ++++++++++++++- tests/handlers/__init__.py | 80 -------------------- tests/handlers/test_authentication.py | 19 +++-- tests/handlers/test_users.py | 4 +- tests/test_authentication.py | 78 ++++++++----------- tests/test_handlers.py | 104 ++++++++++++-------------- 8 files changed, 152 insertions(+), 207 deletions(-) diff --git a/redash/metrics/request.py b/redash/metrics/request.py index 82ee0a6459..b6ceada970 100644 --- a/redash/metrics/request.py +++ b/redash/metrics/request.py @@ -1,8 +1,8 @@ -from collections import namedtuple -import time import logging +import time +from collections import namedtuple -from flask import request, g +from flask import g, request from redash import statsd_client from redash.models import db @@ -43,6 +43,7 @@ def calculate_metrics_on_exception(error): def provision_app(app): - app.before_request(record_requets_start_time) - app.after_request(calculate_metrics) - app.teardown_request(calculate_metrics_on_exception) + # app.before_request(record_requets_start_time) + # app.after_request(calculate_metrics) + # app.teardown_request(calculate_metrics_on_exception) + pass diff --git a/redash/models.py b/redash/models.py index a5f83d6eda..a6fb5701da 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1470,8 +1470,6 @@ def to_dict(self): _gfk_types = {'queries': Query, 'dashboards': Dashboard} -all_models = (Organization, Group, DataSource, DataSourceGroup, User, QueryResult, Query, Alert, Dashboard, Visualization, Widget, Event, NotificationDestination, AlertSubscription, ApiKey, AccessPermission, Change) - def init_db(): default_org = Organization(name="Default", slug='default', settings={}) diff --git a/tests/__init__.py b/tests/__init__.py index dd777490d0..a19f59bad4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,9 @@ import os import logging import datetime +import json from unittest import TestCase +from contextlib import contextmanager os.environ['REDASH_REDIS_URL'] = "redis://localhost:6379/5" # Use different url for Celery to avoid DB being cleaned up: @@ -15,19 +17,37 @@ import redash.models from redash import create_app from redash import redis_connection -from factories import Factory -from tests.handlers import make_request +from redash.utils import json_dumps +from tests.factories import Factory, user_factory + logging.disable("INFO") logging.getLogger("metrics").setLevel("ERROR") +def authenticate_request(c, user): + with c.session_transaction() as sess: + sess['user_id'] = user.id + + +@contextmanager +def authenticated_user(c, user=None): + if not user: + user = user_factory.create() + + authenticate_request(c, user) + + yield user + + class BaseTestCase(TestCase): def setUp(self): - redash.models.db.session.close() self.app = create_app() + self.app.config['TESTING'] = True + self.client = self.app.test_client() self.app_ctx = self.app.app_context() self.app_ctx.push() + redash.models.db.session.remove() redash.models.create_db(True, True) self.factory = Factory() @@ -46,7 +66,38 @@ def make_request(self, method, path, org=None, user=None, data=None, is_json=Tru if org is not False: path = "/{}{}".format(org.slug, path) - return make_request(method, path, user, data, is_json) + if user: + authenticate_request(self.client, user) + + method_fn = getattr(self.client, method.lower()) + headers = {} + + if data and is_json: + data = json_dumps(data) + + if is_json: + content_type = 'application/json' + else: + content_type = None + + response = method_fn(path, data=data, headers=headers, content_type=content_type) + + if response.data and is_json: + response.json = json.loads(response.data) + + return response + + def get_request(self, path, org=None): + if org: + path = "/{}{}".format(org.slug, path) + + return self.client.get(path) + + def post_request(self, path, data=None, org=None): + if org: + path = "/{}{}".format(org.slug, path) + + return self.client.post(path, data=data) def assertResponseEqual(self, expected, actual): for k, v in expected.iteritems(): diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py index 3e50654083..e69de29bb2 100644 --- a/tests/handlers/__init__.py +++ b/tests/handlers/__init__.py @@ -1,80 +0,0 @@ -import json -from contextlib import contextmanager - -from tests.factories import user_factory -from redash.models import db -from redash.utils import json_dumps -from redash.wsgi import app - -app.config['TESTING'] = True - - -def authenticate_request(c, user): - with c.session_transaction() as sess: - if user.id is None: - db.session.flush() - sess['user_id'] = user.id - - -@contextmanager -def authenticated_user(c, user=None): - if not user: - user = user_factory.create() - - authenticate_request(c, user) - - yield user - - -def json_request(method, path, data=None): - if data: - response = method(path, data=json_dumps(data)) - else: - response = method(path) - - if response.data: - response.json = json.loads(response.data) - else: - response.json = None - - return response - - -def make_request(method, path, user, data=None, is_json=True): - with app.test_client() as c: - if user: - authenticate_request(c, user) - - method_fn = getattr(c, method.lower()) - headers = {} - - if data and is_json: - data = json_dumps(data) - - if is_json: - content_type = 'application/json' - else: - content_type = None - - response = method_fn(path, data=data, headers=headers, content_type=content_type) - - if response.data and is_json: - response.json = json.loads(response.data) - - return response - - -def get_request(path, org=None): - if org: - path = "/{}{}".format(org.slug, path) - - with app.test_client() as c: - return c.get(path) - - -def post_request(path, data=None, org=None): - if org: - path = "/{}{}".format(org.slug, path) - - with app.test_client() as c: - return c.post(path, data=data) diff --git a/tests/handlers/test_authentication.py b/tests/handlers/test_authentication.py index 8896aee468..c7989e5d89 100644 --- a/tests/handlers/test_authentication.py +++ b/tests/handlers/test_authentication.py @@ -4,7 +4,6 @@ from redash import settings from redash.models import User from redash.authentication.account import invite_token -from tests.handlers import get_request, post_request class TestInvite(BaseTestCase): @@ -14,16 +13,16 @@ def test_expired_invite_token(self): patched_time.return_value = time.time() - (7 * 24 * 3600) - 10 token = invite_token(self.factory.user) - response = get_request('/invite/{}'.format(token), org=self.factory.org) + response = self.get_request('/invite/{}'.format(token), org=self.factory.org) self.assertEqual(response.status_code, 400) def test_invalid_invite_token(self): - response = get_request('/invite/badtoken', org=self.factory.org) + response = self.get_request('/invite/badtoken', org=self.factory.org) self.assertEqual(response.status_code, 400) def test_valid_token(self): token = invite_token(self.factory.user) - response = get_request('/invite/{}'.format(token), org=self.factory.org) + response = self.get_request('/invite/{}'.format(token), org=self.factory.org) self.assertEqual(response.status_code, 200) def test_already_active_user(self): @@ -33,16 +32,16 @@ def test_already_active_user(self): class TestInvitePost(BaseTestCase): def test_empty_password(self): token = invite_token(self.factory.user) - response = post_request('/invite/{}'.format(token), data={'password': ''}, org=self.factory.org) + response = self.post_request('/invite/{}'.format(token), data={'password': ''}, org=self.factory.org) self.assertEqual(response.status_code, 400) def test_invalid_password(self): token = invite_token(self.factory.user) - response = post_request('/invite/{}'.format(token), data={'password': '1234'}, org=self.factory.org) + response = self.post_request('/invite/{}'.format(token), data={'password': '1234'}, org=self.factory.org) self.assertEqual(response.status_code, 400) def test_bad_token(self): - response = post_request('/invite/{}'.format('jdsnfkjdsnfkj'), data={'password': '1234'}, org=self.factory.org) + response = self.post_request('/invite/{}'.format('jdsnfkjdsnfkj'), data={'password': '1234'}, org=self.factory.org) self.assertEqual(response.status_code, 400) def test_already_active_user(self): @@ -51,7 +50,7 @@ def test_already_active_user(self): def test_valid_password(self): token = invite_token(self.factory.user) password = 'test1234' - response = post_request('/invite/{}'.format(token), data={'password': password}, org=self.factory.org) + response = self.post_request('/invite/{}'.format(token), data={'password': password}, org=self.factory.org) self.assertEqual(response.status_code, 302) user = User.get_by_id(self.factory.user.id) self.assertTrue(user.verify_password(password)) @@ -62,7 +61,7 @@ def test_throttle_login(self): # Extract the limit from settings (ex: '50/day') limit = settings.THROTTLE_LOGIN_PATTERN.split('/')[0] for _ in range(0, int(limit)): - get_request('/login', org=self.factory.org) + self.get_request('/login', org=self.factory.org) - response = get_request('/login', org=self.factory.org) + response = self.get_request('/login', org=self.factory.org) self.assertEqual(response.status_code, 429) diff --git a/tests/handlers/test_users.py b/tests/handlers/test_users.py index f5fb1db2ce..aed7a93743 100644 --- a/tests/handlers/test_users.py +++ b/tests/handlers/test_users.py @@ -1,7 +1,5 @@ -from tests import BaseTestCase -from tests.handlers import authenticated_user, json_request -from redash.wsgi import app from redash import models +from tests import BaseTestCase class TestUserListResourcePost(BaseTestCase): diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 1b10652353..5275975384 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -2,12 +2,12 @@ from flask import request from mock import patch - -from tests import BaseTestCase from redash import models -from redash.authentication.google_oauth import create_and_login_user, verify_profile -from redash.authentication import api_key_load_user_from_request, hmac_load_user_from_request, sign -from redash.wsgi import app +from redash.authentication import (api_key_load_user_from_request, + hmac_load_user_from_request, sign) +from redash.authentication.google_oauth import (create_and_login_user, + verify_profile) +from tests import BaseTestCase class TestApiKeyAuthentication(BaseTestCase): @@ -23,48 +23,40 @@ def setUp(self): self.queries_url = '/{}/api/queries'.format(self.factory.org.slug) def test_no_api_key(self): - with app.test_client() as c: - rv = c.get(self.query_url) - self.assertIsNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.query_url) + self.assertIsNone(api_key_load_user_from_request(request)) def test_wrong_api_key(self): - with app.test_client() as c: - rv = c.get(self.query_url, query_string={'api_key': 'whatever'}) - self.assertIsNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.query_url, query_string={'api_key': 'whatever'}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_correct_api_key(self): - with app.test_client() as c: - rv = c.get(self.query_url, query_string={'api_key': self.api_key}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.query_url, query_string={'api_key': self.api_key}) + self.assertIsNotNone(api_key_load_user_from_request(request)) def test_no_query_id(self): - with app.test_client() as c: - rv = c.get(self.queries_url, query_string={'api_key': self.api_key}) - self.assertIsNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.queries_url, query_string={'api_key': self.api_key}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") models.db.session.flush() - with app.test_client() as c: - rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) - self.assertEqual(user.id, api_key_load_user_from_request(request).id) + rv = self.client.get(self.queries_url, query_string={'api_key': user.api_key}) + self.assertEqual(user.id, api_key_load_user_from_request(request).id) def test_api_key_header(self): - with app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(self.api_key)}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.query_url, headers={'Authorization': "Key {}".format(self.api_key)}) + self.assertIsNotNone(api_key_load_user_from_request(request)) def test_api_key_header_with_wrong_key(self): - with app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key oops"}) - self.assertIsNone(api_key_load_user_from_request(request)) + rv = self.client.get(self.query_url, headers={'Authorization': "Key oops"}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_api_key_for_wrong_org(self): other_user = self.factory.create_admin(org=self.factory.create_org()) - with app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(other_user.api_key)}) - self.assertEqual(404, rv.status_code) + rv = self.client.get(self.query_url, headers={'Authorization': "Key {}".format(other_user.api_key)}) + self.assertEqual(404, rv.status_code) class TestHMACAuthentication(BaseTestCase): @@ -83,33 +75,29 @@ def signature(self, expires): return sign(self.query.api_key, self.path, expires) def test_no_signature(self): - with app.test_client() as c: - rv = c.get(self.path) - self.assertIsNone(hmac_load_user_from_request(request)) + rv = self.client.get(self.path) + self.assertIsNone(hmac_load_user_from_request(request)) def test_wrong_signature(self): - with app.test_client() as c: - rv = c.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) - self.assertIsNone(hmac_load_user_from_request(request)) + rv = self.client.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) + self.assertIsNone(hmac_load_user_from_request(request)) def test_correct_signature(self): - with app.test_client() as c: - rv = c.get(self.path, query_string={'signature': self.signature(self.expires), 'expires': self.expires}) - self.assertIsNotNone(hmac_load_user_from_request(request)) + rv = self.client.get(self.path, query_string={'signature': self.signature(self.expires), 'expires': self.expires}) + self.assertIsNotNone(hmac_load_user_from_request(request)) def test_no_query_id(self): - with app.test_client() as c: - rv = c.get('/{}/api/queries'.format(self.query.org.slug), query_string={'api_key': self.api_key}) - self.assertIsNone(hmac_load_user_from_request(request)) + rv = self.client.get('/{}/api/queries'.format(self.query.org.slug), query_string={'api_key': self.api_key}) + self.assertIsNone(hmac_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") path = '/api/queries/' models.db.session.flush() - with app.test_client() as c: - signature = sign(user.api_key, path, self.expires) - rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) - self.assertEqual(user, hmac_load_user_from_request(request)) + + signature = sign(user.api_key, path, self.expires) + rv = self.client.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) + self.assertEqual(user.id, hmac_load_user_from_request(request).id) class TestCreateAndLoginUser(BaseTestCase): diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 2cff6e8dff..d3adf54bbc 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,20 +1,18 @@ -import json from unittest import TestCase + from flask import url_for from flask_login import current_user from mock import patch -from tests import BaseTestCase -from tests.handlers import authenticated_user from redash import models, settings -from redash.wsgi import app +from tests import BaseTestCase +from tests import authenticated_user class AuthenticationTestMixin(object): def test_returns_404_when_not_unauthenticated(self): - with app.test_client() as c: - for path in self.paths: - rv = c.get(path) - self.assertEquals(404, rv.status_code) + for path in self.paths: + rv = self.client.get(path) + self.assertEquals(404, rv.status_code) def test_returns_content_when_authenticated(self): for path in self.paths: @@ -24,17 +22,15 @@ def test_returns_content_when_authenticated(self): class TestAuthentication(BaseTestCase): def test_redirects_for_nonsigned_in_user(self): - with app.test_client() as c: - rv = c.get("/default/") - self.assertEquals(302, rv.status_code) + rv = self.client.get("/default/") + self.assertEquals(302, rv.status_code) class PingTest(TestCase): def test_ping(self): - with app.test_client() as c: - rv = c.get('/ping') - self.assertEquals(200, rv.status_code) - self.assertEquals('PONG.', rv.data) + rv = self.client.get('/ping') + self.assertEquals(200, rv.status_code) + self.assertEquals('PONG.', rv.data) class IndexTest(BaseTestCase): @@ -43,10 +39,9 @@ def setUp(self): super(IndexTest, self).setUp() def test_redirect_to_login_when_not_authenticated(self): - with app.test_client() as c: - for path in self.paths: - rv = c.get(path) - self.assertEquals(302, rv.status_code) + for path in self.paths: + rv = self.client.get(path) + self.assertEquals(302, rv.status_code) def test_returns_content_when_authenticated(self): for path in self.paths: @@ -66,9 +61,8 @@ def test_returns_403_for_non_admin(self): self.assertEqual(rv.status_code, 403) def test_redirects_non_authenticated_user(self): - with app.test_client() as c: - rv = c.get('/status.json') - self.assertEqual(rv.status_code, 302) + rv = self.client.get('/status.json') + self.assertEqual(rv.status_code, 302) class VisualizationResourceTest(BaseTestCase): @@ -77,7 +71,7 @@ def test_create_visualization(self): data = { 'query_id': query.id, 'name': 'Chart', - 'description':'', + 'description': '', 'options': {}, 'type': 'CHART' } @@ -110,7 +104,7 @@ def test_only_owner_or_admin_can_create_visualization(self): data = { 'query_id': query.id, 'name': 'Chart', - 'description':'', + 'description': '', 'options': {}, 'type': 'CHART' } @@ -131,7 +125,7 @@ def test_only_owner_or_admin_can_create_visualization(self): def test_only_owner_or_admin_can_edit_visualization(self): vis = self.factory.create_visualization() path = '/api/visualizations/{}'.format(vis.id) - data={'name': 'After Update'} + data = {'name': 'After Update'} other_user = self.factory.create_user() admin = self.factory.create_admin() @@ -190,19 +184,18 @@ def tearDownClass(cls): settings.ORG_RESOLVING = "multi_org" def test_redirects_to_google_login_if_password_disabled(self): - with app.test_client() as c, patch.object(settings, 'PASSWORD_LOGIN_ENABLED', False): - rv = c.get('/default/login') + with patch.object(settings, 'PASSWORD_LOGIN_ENABLED', False): + rv = self.client.get('/default/login') self.assertEquals(rv.status_code, 302) self.assertTrue(rv.location.endswith(url_for('google_oauth.authorize', next='/default/'))) def test_get_login_form(self): - with app.test_client() as c: - rv = c.get('/default/login') - self.assertEquals(rv.status_code, 200) + rv = self.client.get('/default/login') + self.assertEquals(rv.status_code, 200) def test_submit_non_existing_user(self): - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': 'arik', 'password': 'password'}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': 'arik', 'password': 'password'}) self.assertEquals(rv.status_code, 200) self.assertFalse(login_user_mock.called) @@ -211,8 +204,8 @@ def test_submit_correct_user_and_password(self): user.hash_password('password') user.save() - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': user.email, 'password': 'password'}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password'}) self.assertEquals(rv.status_code, 302) login_user_mock.assert_called_with(user, remember=False) @@ -221,8 +214,8 @@ def test_submit_correct_user_and_password_and_remember_me(self): user.hash_password('password') user.save() - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': user.email, 'password': 'password', 'remember': True}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password', 'remember': True}) self.assertEquals(rv.status_code, 302) login_user_mock.assert_called_with(user, remember=True) @@ -231,16 +224,16 @@ def test_submit_correct_user_and_password_with_next(self): user.hash_password('password') user.save() - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login?next=/test', - data={'email': user.email, 'password': 'password'}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login?next=/test', + data={'email': user.email, 'password': 'password'}) self.assertEquals(rv.status_code, 302) self.assertEquals(rv.location, 'http://localhost/test') login_user_mock.assert_called_with(user, remember=False) def test_submit_incorrect_user(self): - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': 'non-existing', 'password': 'password'}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': 'non-existing', 'password': 'password'}) self.assertEquals(rv.status_code, 200) self.assertFalse(login_user_mock.called) @@ -249,39 +242,36 @@ def test_submit_incorrect_password(self): user.hash_password('password') user.save() - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': user.email, 'password': 'badbadpassword'}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': user.email, 'password': 'badbadpassword'}) self.assertEquals(rv.status_code, 200) self.assertFalse(login_user_mock.called) - def test_submit_incorrect_password(self): + def test_submit_empty_password(self): user = self.factory.user - with app.test_client() as c, patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.post('/default/login', data={'email': user.email, 'password': ''}) + with patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.post('/default/login', data={'email': user.email, 'password': ''}) self.assertEquals(rv.status_code, 200) self.assertFalse(login_user_mock.called) def test_user_already_loggedin(self): - with app.test_client() as c, authenticated_user(c), patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = c.get('/default/login') + with authenticated_user(self.client), patch('redash.handlers.authentication.login_user') as login_user_mock: + rv = self.client.get('/default/login') self.assertEquals(rv.status_code, 302) self.assertFalse(login_user_mock.called) - # TODO: brute force protection? - class TestLogout(BaseTestCase): def test_logout_when_not_loggedin(self): - with app.test_client() as c: - rv = c.get('/default/logout') - self.assertEquals(rv.status_code, 302) - self.assertFalse(current_user.is_authenticated) + rv = self.client.get('/default/logout') + self.assertEquals(rv.status_code, 302) + self.assertFalse(current_user.is_authenticated) def test_logout_when_loggedin(self): - with app.test_client() as c, authenticated_user(c, user=self.factory.user): - rv = c.get('/default/') + with authenticated_user(self.client, user=self.factory.user): + rv = self.client.get('/default/') self.assertTrue(current_user.is_authenticated) - rv = c.get('/default/logout') + rv = self.client.get('/default/logout') self.assertEquals(rv.status_code, 302) self.assertFalse(current_user.is_authenticated) From 55cb3747ed712b28f9f3a21af05e09f5c53f1201 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 13:22:22 +0200 Subject: [PATCH 09/80] Use db.drop_all/create_all directly --- redash/models.py | 1 + tests/__init__.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redash/models.py b/redash/models.py index a6fb5701da..dc81d93073 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1483,6 +1483,7 @@ def init_db(): def create_db(create_tables, drop_tables): + # TODO: use these methods directly if drop_tables: db.session.rollback() db.drop_all() diff --git a/tests/__init__.py b/tests/__init__.py index a19f59bad4..73d37fb510 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,9 +14,9 @@ os.environ['REDASH_GOOGLE_CLIENT_SECRET'] = "dummy" os.environ['REDASH_MULTI_ORG'] = "true" -import redash.models from redash import create_app from redash import redis_connection +from redash.models import db from redash.utils import json_dumps from tests.factories import Factory, user_factory @@ -44,15 +44,15 @@ class BaseTestCase(TestCase): def setUp(self): self.app = create_app() self.app.config['TESTING'] = True - self.client = self.app.test_client() self.app_ctx = self.app.app_context() self.app_ctx.push() - redash.models.db.session.remove() - redash.models.create_db(True, True) + db.create_all() self.factory = Factory() + self.client = self.app.test_client() def tearDown(self): - redash.models.create_db(False, True) + db.session.remove() + db.drop_all() self.app_ctx.pop() redis_connection.flushdb() From 04447e0df68abe522a7a3ecc459c03883cbf6ee0 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 15:20:10 +0200 Subject: [PATCH 10/80] Fix: connections leaking during tests. --- tests/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/__init__.py b/tests/__init__.py index 73d37fb510..9108343127 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -43,16 +43,18 @@ def authenticated_user(c, user=None): class BaseTestCase(TestCase): def setUp(self): self.app = create_app() + self.db = db self.app.config['TESTING'] = True self.app_ctx = self.app.app_context() self.app_ctx.push() + db.drop_all() db.create_all() self.factory = Factory() self.client = self.app.test_client() def tearDown(self): db.session.remove() - db.drop_all() + db.get_engine(self.app).dispose() self.app_ctx.pop() redis_connection.flushdb() From 2a525210e4196cd58a6193d3a8549687c273999e Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Mon, 28 Nov 2016 18:55:41 +0200 Subject: [PATCH 11/80] Start fixing visualizations tests --- redash/authentication/__init__.py | 7 +++--- redash/handlers/authentication.py | 2 +- redash/handlers/base.py | 5 ++++ redash/handlers/visualizations.py | 2 +- redash/models.py | 42 ++++++++++++++----------------- tests/factories.py | 5 ++-- tests/test_handlers.py | 18 +++++++++---- 7 files changed, 45 insertions(+), 36 deletions(-) diff --git a/redash/authentication/__init__.py b/redash/authentication/__init__.py index 553669fd4b..ddbbedc8e3 100644 --- a/redash/authentication/__init__.py +++ b/redash/authentication/__init__.py @@ -5,7 +5,6 @@ import logging from flask import redirect, request, jsonify, url_for -from sqlalchemy.orm.exc import NoResultFound from redash import models, settings from redash.authentication.org_resolving import current_org @@ -40,7 +39,7 @@ def load_user(user_id): org = current_org._get_current_object() try: return models.User.get_by_id_and_org(user_id, org) - except NoResultFound: + except models.NoResultFound: return None @@ -79,11 +78,11 @@ def get_user_from_api_key(api_key, query_id): org = current_org._get_current_object() try: user = models.User.get_by_api_key_and_org(api_key, org) - except NoResultFound: + except models.NoResultFound: try: api_key = models.ApiKey.get_by_api_key(api_key) user = models.ApiUser(api_key, api_key.org, []) - except NoResultFound: + except models.NoResultFound: if query_id: query = models.Query.get_by_id_and_org(query_id, org) if query and query.api_key == api_key: diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index 9c492f3e05..5da8bb4795 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -110,7 +110,7 @@ def login(org_slug=None): return redirect(next_path) else: flash("Wrong email or password.") - except models.User.DoesNotExist: + except models.NoResultFound: flash("Wrong email or password.") google_auth_url = get_google_auth_url(next_path) diff --git a/redash/handlers/base.py b/redash/handlers/base.py index 465b7480fe..fc537d6639 100644 --- a/redash/handlers/base.py +++ b/redash/handlers/base.py @@ -37,6 +37,11 @@ def current_org(self): def record_event(self, options): record_event(self.current_org, self.current_user, options) + # TODO: this should probably be somewhere else + def update_model(self, model, updates): + for k, v in updates.items(): + setattr(model, k, v) + def record_event(org, user, options): if isinstance(user, ApiUser): diff --git a/redash/handlers/visualizations.py b/redash/handlers/visualizations.py index cd793a1d81..f21cc2e208 100644 --- a/redash/handlers/visualizations.py +++ b/redash/handlers/visualizations.py @@ -35,7 +35,7 @@ def post(self, visualization_id): kwargs.pop('id', None) kwargs.pop('query_id', None) - vis.update_instance(**kwargs) + self.update_model(vis, kwargs) return vis.to_dict(with_query=False) diff --git a/redash/models.py b/redash/models.py index dc81d93073..0eb21d433b 100644 --- a/redash/models.py +++ b/redash/models.py @@ -4,28 +4,23 @@ import itertools import json import logging -import os -import threading import time from funcy import project - from flask_sqlalchemy import SQLAlchemy from flask.ext.sqlalchemy import SignallingSession from flask_login import UserMixin, AnonymousUserMixin from sqlalchemy.dialects import postgresql from sqlalchemy.event import listens_for from sqlalchemy.types import TypeDecorator +from sqlalchemy.orm import object_session +# noinspection PyUnresolvedReferences +from sqlalchemy.orm.exc import NoResultFound from passlib.apps import custom_app_context as pwd_context -from playhouse.gfk import GFKField, BaseModel -from playhouse.postgres_ext import ArrayField, DateTimeTZField - - -from redash import redis_connection, settings, utils +from redash import redis_connection, utils from redash.destinations import get_destination, get_configuration_schema_for_destination_type -from redash.metrics.database import MeteredPostgresqlExtDatabase, MeteredModel from redash.permissions import has_access, view_only from redash.query_runner import get_query_runner, get_configuration_schema_for_query_runner_type from redash.utils import generate_token, json_dumps @@ -39,6 +34,7 @@ # TODO replace this with association tables. _gfk_types = {} + class GFKBase(object): """ Compatibility with 'generic foreign key' approach Peewee used. @@ -121,14 +117,14 @@ def record_changes(self, changed_by): change=changes)) - class ConflictDetectedError(Exception): pass + class BelongsToOrgMixin(object): @classmethod def get_by_id_and_org(cls, object_id, org): - return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one_or_none() + return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one() class PermissionsCheckMixin(object): @@ -143,6 +139,7 @@ def has_permissions(self, permissions): return has_permissions + class AnonymousUser(AnonymousUserMixin, PermissionsCheckMixin): @property def permissions(self): @@ -826,12 +823,6 @@ def tracked_save(self, changing_user, old_object=None, *args, **kwargs): new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self) return new_change - def _create_default_visualizations(self): - table_visualization = Visualization(query=self, name="Table", - description='', - type="TABLE", options="{}") - table_visualization.save() - def _set_api_key(self): if not self.api_key: self.api_key = hashlib.sha1( @@ -863,11 +854,12 @@ def gen_query_hash(target, val, oldval, initiator): def query_last_modified_by(target, val, oldval, initiator): target.last_modified_by_id = val +# Create default (table) visualization: @listens_for(SignallingSession, 'before_flush') def create_defaults(session, ctx, *a): for obj in session.new: if isinstance(obj, Query): - session.add(Visualization(query=obj, name="Table", + session.add(Visualization(query_rel=obj, name="Table", description='', type="TABLE", options="{}")) @@ -1194,7 +1186,6 @@ def tracked_save(self, changing_user, old_object=None, *args, **kwargs): new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self) return new_change - def __unicode__(self): return u"%s=%s" % (self.id, self.name) @@ -1203,7 +1194,8 @@ class Visualization(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) type = Column(db.String(100)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) - query = db.relationship(Query, backref='visualizations') + # query_rel and not query, because db.Model already has query defined. + query_rel = db.relationship(Query, backref='visualizations') name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) options = Column(db.Text) @@ -1222,14 +1214,18 @@ def to_dict(self, with_query=True): } if with_query: - d['query'] = self.query.to_dict() + d['query'] = self.query_rel.to_dict() return d @classmethod def get_by_id_and_org(cls, visualization_id, org): - return cls.select(Visualization, Query).join(Query).where(cls.id == visualization_id, - Query.org == org).get() + if isinstance(org, Organization): + org_id = org.id + else: + org_id = org + + return cls.query.join(Query).filter(cls.id == visualization_id, Query.org_id == org_id).one() def __unicode__(self): return u"%s %s" % (self.id, self.type) diff --git a/tests/factories.py b/tests/factories.py index f23f427965..b1ae77c038 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -24,6 +24,7 @@ def create(self, **override_kwargs): kwargs = self._get_kwargs(override_kwargs) obj = self.model(**kwargs) db.session.add(obj) + db.session.commit() return obj @@ -53,7 +54,7 @@ def __call__(self): type='pg', # If we don't use lambda here it will reuse the same options between tests: options=lambda: ConfigurationContainer.from_json('{"dbname": "test"}'), - org=1) + org_id=1) dashboard_factory = ModelFactory(redash.models.Dashboard, name='test', user=user_factory.create, layout='[]', org=1) @@ -107,7 +108,7 @@ def __call__(self): visualization_factory = ModelFactory(redash.models.Visualization, type='CHART', - query_text=query_factory.create, + query_rel=query_factory.create, name='Chart', description='', options='{}') diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d3adf54bbc..6e7abf36c3 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -26,7 +26,7 @@ def test_redirects_for_nonsigned_in_user(self): self.assertEquals(302, rv.status_code) -class PingTest(TestCase): +class PingTest(BaseTestCase): def test_ping(self): rv = self.client.get('/ping') self.assertEquals(200, rv.status_code) @@ -202,7 +202,9 @@ def test_submit_non_existing_user(self): def test_submit_correct_user_and_password(self): user = self.factory.user user.hash_password('password') - user.save() + + self.db.session.add(user) + self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password'}) @@ -212,7 +214,9 @@ def test_submit_correct_user_and_password(self): def test_submit_correct_user_and_password_and_remember_me(self): user = self.factory.user user.hash_password('password') - user.save() + + self.db.session.add(user) + self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password', 'remember': True}) @@ -222,7 +226,9 @@ def test_submit_correct_user_and_password_and_remember_me(self): def test_submit_correct_user_and_password_with_next(self): user = self.factory.user user.hash_password('password') - user.save() + + self.db.session.add(user) + self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: rv = self.client.post('/default/login?next=/test', @@ -240,7 +246,9 @@ def test_submit_incorrect_user(self): def test_submit_incorrect_password(self): user = self.factory.user user.hash_password('password') - user.save() + + self.db.session.add(user) + self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: rv = self.client.post('/default/login', data={'email': user.email, 'password': 'badbadpassword'}) From 90879c964f3e5ea41a7f846faa98db088da9b5d6 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 00:12:26 +0200 Subject: [PATCH 12/80] Simplify query api key generation --- redash/models.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/redash/models.py b/redash/models.py index 0eb21d433b..f0065ca816 100644 --- a/redash/models.py +++ b/redash/models.py @@ -597,13 +597,6 @@ def should_schedule_next(previous_iteration, now, schedule): return now > next_iteration -def generate_query_api_key(ctx): - return hashlib.sha1(u''.join(( - str(time.time()), ctx.current_parameters['query'], - str(ctx.current_parameters['user_id']), - ctx.current_parameters['name'])).encode('utf-8')).hexdigest() - - class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) version = Column(db.Integer) @@ -617,7 +610,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): description = Column(db.String(4096), nullable=True) query_text = Column("query", db.Text) query_hash = Column(db.String(32)) - api_key = Column(db.String(40), default=generate_query_api_key) + api_key = Column(db.String(40), default=lambda: generate_token(40)) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, foreign_keys=[user_id]) last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True) @@ -800,7 +793,6 @@ def fork(self, user): def pre_save(self, created): super(Query, self).pre_save(created) self.query_hash = utils.gen_query_hash(self.query) - self._set_api_key() if self.last_modified_by is None: self.last_modified_by = self.user @@ -823,11 +815,6 @@ def tracked_save(self, changing_user, old_object=None, *args, **kwargs): new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self) return new_change - def _set_api_key(self): - if not self.api_key: - self.api_key = hashlib.sha1( - u''.join((str(time.time()), self.query, str(self.user_id), self.name)).encode('utf-8')).hexdigest() - @property def runtime(self): return self.latest_query_data.runtime From c2378d837ac74fc69c6ae15b43ced50eece66e61 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Mon, 28 Nov 2016 22:30:35 -0600 Subject: [PATCH 13/80] test_handlers passes --- redash/handlers/authentication.py | 16 ++++--- redash/handlers/visualizations.py | 24 +++++----- redash/models.py | 11 ++--- redash/monitor.py | 10 ++-- tests/__init__.py | 6 ++- tests/factories.py | 8 ++-- tests/handlers/__init__.py | 80 +++++++++++++++++++++++++++++++ tests/test_handlers.py | 49 +++++++++++++------ 8 files changed, 154 insertions(+), 50 deletions(-) diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index 5da8bb4795..c32b9d7394 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -3,6 +3,9 @@ from flask import flash, redirect, render_template, request, url_for from flask_login import current_user, login_required, login_user, logout_user + +from sqlalchemy.orm.exc import NoResultFound + from redash import __version__, limiter, models, settings from redash.authentication import current_org, get_login_url from redash.authentication.account import (BadSignature, SignatureExpired, @@ -27,7 +30,7 @@ def render_token_login_page(template, org_slug, token): try: user_id = validate_token(token) user = models.User.get_by_id_and_org(user_id, current_org) - except models.User.DoesNotExist: + except NoResultFound: logger.exception("Bad user id in token. Token= , User id= %s, Org=%s", user_id, token, org_slug) return render_template("error.html", error_message="Invalid invite link. Please ask for a new one."), 400 except (SignatureExpired, BadSignature): @@ -48,9 +51,9 @@ def render_token_login_page(template, org_slug, token): else: # TODO: set active flag user.hash_password(request.form['password']) - user.save() - + models.db.session.add(user) login_user(user) + models.db.session.commit() return redirect(url_for('redash.index', org_slug=org_slug)) if settings.GOOGLE_OAUTH_ENABLED: google_auth_url = get_google_auth_url(url_for('redash.index', org_slug=org_slug)) @@ -78,7 +81,7 @@ def forgot_password(org_slug=None): try: user = models.User.get_by_email_and_org(email, current_org) send_password_reset_email(user) - except models.User.DoesNotExist: + except NoResultFound: logging.error("No user found for forgot password: %s", email) return render_template("forgot.html", submitted=submitted) @@ -89,7 +92,6 @@ def forgot_password(org_slug=None): def login(org_slug=None): index_url = url_for("redash.index", org_slug=org_slug) next_path = request.args.get('next', index_url) - if current_user.is_authenticated: return redirect(next_path) @@ -103,14 +105,14 @@ def login(org_slug=None): if request.method == 'POST': try: - user = models.User.get_by_email_and_org(request.form['email'], current_org.id) + user = models.User.get_by_email_and_org(request.form['email'], current_org) if user and user.verify_password(request.form['password']): remember = ('remember' in request.form) login_user(user, remember=remember) return redirect(next_path) else: flash("Wrong email or password.") - except models.NoResultFound: + except NoResultFound: flash("Wrong email or password.") google_auth_url = get_google_auth_url(next_path) diff --git a/redash/handlers/visualizations.py b/redash/handlers/visualizations.py index f21cc2e208..67c92927a9 100644 --- a/redash/handlers/visualizations.py +++ b/redash/handlers/visualizations.py @@ -15,18 +15,19 @@ def post(self): require_admin_or_owner(query.user_id) kwargs['options'] = json.dumps(kwargs['options']) - kwargs['query'] = query - - vis = models.Visualization.create(**kwargs) - - return vis.to_dict(with_query=False) + kwargs['query_rel'] = query + vis = models.Visualization(**kwargs) + models.db.session.add(vis) + d = vis.to_dict(with_query=False) + models.db.session.commit() + return d class VisualizationResource(BaseResource): @require_permission('edit_query') def post(self, visualization_id): vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) - require_admin_or_owner(vis.query.user_id) + require_admin_or_owner(vis.query_rel.user_id) kwargs = request.get_json(force=True) if 'options' in kwargs: @@ -36,12 +37,13 @@ def post(self, visualization_id): kwargs.pop('query_id', None) self.update_model(vis, kwargs) - - return vis.to_dict(with_query=False) + d = vis.to_dict(with_query=False) + models.db.session.commit() + return d @require_permission('edit_query') def delete(self, visualization_id): vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) - require_admin_or_owner(vis.query.user_id) - - vis.delete_instance() + require_admin_or_owner(vis.query_rel.user_id) + models.db.session.delete(vis) + models.db.session.commit() diff --git a/redash/models.py b/redash/models.py index f0065ca816..7884ca831d 100644 --- a/redash/models.py +++ b/redash/models.py @@ -707,7 +707,7 @@ def by_user(cls, user, drafts): @classmethod def outdated_queries(cls): - queries = (cls.query(Query) + queries = (db.session.query(Query) .join(QueryResult) .join(DataSource) .filter(Query.schedule != None)) @@ -1207,12 +1207,9 @@ def to_dict(self, with_query=True): @classmethod def get_by_id_and_org(cls, visualization_id, org): - if isinstance(org, Organization): - org_id = org.id - else: - org_id = org - - return cls.query.join(Query).filter(cls.id == visualization_id, Query.org_id == org_id).one() + return db.session.query(Visualization).join(Query).filter( + cls.id == visualization_id, + Query.org == org).one() def __unicode__(self): return u"%s %s" % (self.id, self.type) diff --git a/redash/monitor.py b/redash/monitor.py index d5b8503dfc..b164e33c97 100644 --- a/redash/monitor.py +++ b/redash/monitor.py @@ -6,12 +6,12 @@ def get_status(): info = redis_connection.info() status['redis_used_memory'] = info['used_memory_human'] status['version'] = __version__ - status['queries_count'] = models.Query.select().count() + status['queries_count'] = models.db.session.query(models.Query).count() if settings.FEATURE_SHOW_QUERY_RESULTS_COUNT: - status['query_results_count'] = models.QueryResult.select().count() + status['query_results_count'] = models.db.session.query(models.QueryResult).count() status['unused_query_results_count'] = models.QueryResult.unused().count() - status['dashboards_count'] = models.Dashboard.select().count() - status['widgets_count'] = models.Widget.select().count() + status['dashboards_count'] = models.Dashboard.query.count() + status['widgets_count'] = models.Widget.query.count() status['workers'] = [] @@ -20,7 +20,7 @@ def get_status(): status['manager']['outdated_queries_count'] = len(models.Query.outdated_queries()) queues = {} - for ds in models.DataSource.select(): + for ds in models.DataSource.query: for queue in (ds.queue_name, ds.scheduled_queue_name): queues.setdefault(queue, set()) queues[queue].add(ds.name) diff --git a/tests/__init__.py b/tests/__init__.py index 9108343127..908acd0095 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -34,7 +34,7 @@ def authenticate_request(c, user): def authenticated_user(c, user=None): if not user: user = user_factory.create() - + db.session.commit() authenticate_request(c, user) yield user @@ -47,6 +47,7 @@ def setUp(self): self.app.config['TESTING'] = True self.app_ctx = self.app.app_context() self.app_ctx.push() + db.session.close() db.drop_all() db.create_all() self.factory = Factory() @@ -58,7 +59,8 @@ def tearDown(self): self.app_ctx.pop() redis_connection.flushdb() - def make_request(self, method, path, org=None, user=None, data=None, is_json=True): + def make_request(self, method, path, org=None, user=None, data=None, + is_json=True): if user is None: user = self.factory.user diff --git a/tests/factories.py b/tests/factories.py index b1ae77c038..eff5abf4ce 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -143,7 +143,9 @@ def __init__(self): def user(self): if self._user is None: self._user = self.create_user() - + # Test setup creates users, they need to be in the db by the time + # the handler's db transaction starts. + db.session.commit() return self._user @property @@ -303,14 +305,14 @@ def create_query_result(self, **kwargs): def create_visualization(self, **kwargs): args = { - 'query': self.create_query() + 'query_rel': self.create_query() } args.update(kwargs) return visualization_factory.create(**args) def create_visualization_with_params(self, **kwargs): args = { - 'query': self.create_query_with_params() + 'query_rel': self.create_query_with_params() } args.update(kwargs) return visualization_factory.create(**args) diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py index e69de29bb2..9375c8b4f2 100644 --- a/tests/handlers/__init__.py +++ b/tests/handlers/__init__.py @@ -0,0 +1,80 @@ +import json +from contextlib import contextmanager + +from tests.factories import user_factory +from redash.models import db +from redash.utils import json_dumps +from redash.wsgi import app + +app.config['TESTING'] = True + + +def authenticate_request(c, user): + with c.session_transaction() as sess: + if user.id is None: + db.session.flush() + sess['user_id'] = user.id + + +@contextmanager +def authenticated_user(c, user=None): + if not user: + user = user_factory.create() + db.session.commit() + authenticate_request(c, user) + + yield user + + +def json_request(method, path, data=None): + if data: + response = method(path, data=json_dumps(data)) + else: + response = method(path) + + if response.data: + response.json = json.loads(response.data) + else: + response.json = None + + return response + + +def make_request(method, path, user, data=None, is_json=True): + with app.test_client() as c: + if user: + authenticate_request(c, user) + + method_fn = getattr(c, method.lower()) + headers = {} + + if data and is_json: + data = json_dumps(data) + + if is_json: + content_type = 'application/json' + else: + content_type = None + + response = method_fn(path, data=data, headers=headers, content_type=content_type) + + if response.data and is_json: + response.json = json.loads(response.data) + + return response + + +def get_request(path, org=None): + if org: + path = "/{}{}".format(org.slug, path) + + with app.test_client() as c: + return c.get(path) + + +def post_request(path, data=None, org=None): + if org: + path = "/{}{}".format(org.slug, path) + + with app.test_client() as c: + return c.post(path, data=data) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 6e7abf36c3..2fa1381c4d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -52,7 +52,7 @@ def test_returns_content_when_authenticated(self): class StatusTest(BaseTestCase): def test_returns_data_for_super_admin(self): admin = self.factory.create_admin() - + models.db.session.commit() rv = self.make_request('get', '/status.json', org=False, user=admin, is_json=False) self.assertEqual(rv.status_code, 200) @@ -68,6 +68,7 @@ def test_redirects_non_authenticated_user(self): class VisualizationResourceTest(BaseTestCase): def test_create_visualization(self): query = self.factory.create_query() + models.db.session.commit() data = { 'query_id': query.id, 'name': 'Chart', @@ -84,16 +85,16 @@ def test_create_visualization(self): def test_delete_visualization(self): visualization = self.factory.create_visualization() - + models.db.session.commit() rv = self.make_request('delete', '/api/visualizations/{}'.format(visualization.id)) self.assertEquals(rv.status_code, 200) # =1 because each query has a default table visualization. - self.assertEquals(models.Visualization.select().count(), 1) + self.assertEquals(models.db.session.query(models.Visualization).count(), 1) def test_update_visualization(self): visualization = self.factory.create_visualization() - + models.db.session.commit() rv = self.make_request('post', '/api/visualizations/{0}'.format(visualization.id), data={'name': 'After Update'}) self.assertEquals(rv.status_code, 200) @@ -101,6 +102,13 @@ def test_update_visualization(self): def test_only_owner_or_admin_can_create_visualization(self): query = self.factory.create_query() + other_user = self.factory.create_user() + admin = self.factory.create_admin() + admin_from_diff_org = self.factory.create_admin(org=self.factory.create_org()) + models.db.session.commit() + models.db.session.refresh(admin) + models.db.session.refresh(other_user) + models.db.session.refresh(admin_from_diff_org) data = { 'query_id': query.id, 'name': 'Chart', @@ -109,9 +117,6 @@ def test_only_owner_or_admin_can_create_visualization(self): 'type': 'CHART' } - other_user = self.factory.create_user() - admin = self.factory.create_admin() - admin_from_diff_org = self.factory.create_admin(org=self.factory.create_org()) rv = self.make_request('post', '/api/visualizations', data=data, user=admin) self.assertEquals(rv.status_code, 200) @@ -124,12 +129,17 @@ def test_only_owner_or_admin_can_create_visualization(self): def test_only_owner_or_admin_can_edit_visualization(self): vis = self.factory.create_visualization() + models.db.session.flush() path = '/api/visualizations/{}'.format(vis.id) data = {'name': 'After Update'} other_user = self.factory.create_user() admin = self.factory.create_admin() admin_from_diff_org = self.factory.create_admin(org=self.factory.create_org()) + models.db.session.commit() + models.db.session.refresh(admin) + models.db.session.refresh(other_user) + models.db.session.refresh(admin_from_diff_org) rv = self.make_request('post', path, user=admin, data=data) self.assertEquals(rv.status_code, 200) @@ -142,22 +152,29 @@ def test_only_owner_or_admin_can_edit_visualization(self): def test_only_owner_or_admin_can_delete_visualization(self): vis = self.factory.create_visualization() + models.db.session.flush() path = '/api/visualizations/{}'.format(vis.id) other_user = self.factory.create_user() admin = self.factory.create_admin() admin_from_diff_org = self.factory.create_admin(org=self.factory.create_org()) + models.db.session.commit() + models.db.session.refresh(admin) + models.db.session.refresh(other_user) + models.db.session.refresh(admin_from_diff_org) rv = self.make_request('delete', path, user=admin) self.assertEquals(rv.status_code, 200) vis = self.factory.create_visualization() + models.db.session.commit() path = '/api/visualizations/{}'.format(vis.id) rv = self.make_request('delete', path, user=other_user) self.assertEquals(rv.status_code, 403) vis = self.factory.create_visualization() + models.db.session.commit() path = '/api/visualizations/{}'.format(vis.id) rv = self.make_request('delete', path, user=admin_from_diff_org) @@ -184,7 +201,7 @@ def tearDownClass(cls): settings.ORG_RESOLVING = "multi_org" def test_redirects_to_google_login_if_password_disabled(self): - with patch.object(settings, 'PASSWORD_LOGIN_ENABLED', False): + with patch.object(settings, 'PASSWORD_LOGIN_ENABLED', False), self.app.test_request_context('/default/login'): rv = self.client.get('/default/login') self.assertEquals(rv.status_code, 302) self.assertTrue(rv.location.endswith(url_for('google_oauth.authorize', next='/default/'))) @@ -251,7 +268,8 @@ def test_submit_incorrect_password(self): self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': user.email, 'password': 'badbadpassword'}) + rv = self.client.post('/default/login', data={ + 'email': user.email, 'password': 'badbadpassword'}) self.assertEquals(rv.status_code, 200) self.assertFalse(login_user_mock.called) @@ -272,14 +290,15 @@ def test_user_already_loggedin(self): class TestLogout(BaseTestCase): def test_logout_when_not_loggedin(self): - rv = self.client.get('/default/logout') - self.assertEquals(rv.status_code, 302) - self.assertFalse(current_user.is_authenticated) + with self.app.test_client() as c: + rv = c.get('/default/logout') + self.assertEquals(rv.status_code, 302) + self.assertFalse(current_user.is_authenticated) def test_logout_when_loggedin(self): - with authenticated_user(self.client, user=self.factory.user): - rv = self.client.get('/default/') + with self.app.test_client() as c, authenticated_user(c, user=self.factory.user): + rv = c.get('/default/') self.assertTrue(current_user.is_authenticated) - rv = self.client.get('/default/logout') + rv = c.get('/default/logout') self.assertEquals(rv.status_code, 302) self.assertFalse(current_user.is_authenticated) From 9210f5fb0cb30206d364ac05bd47fe280fdd6836 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 12:06:58 +0200 Subject: [PATCH 14/80] Remove unused code --- tests/handlers/__init__.py | 80 -------------------------------------- 1 file changed, 80 deletions(-) diff --git a/tests/handlers/__init__.py b/tests/handlers/__init__.py index 9375c8b4f2..e69de29bb2 100644 --- a/tests/handlers/__init__.py +++ b/tests/handlers/__init__.py @@ -1,80 +0,0 @@ -import json -from contextlib import contextmanager - -from tests.factories import user_factory -from redash.models import db -from redash.utils import json_dumps -from redash.wsgi import app - -app.config['TESTING'] = True - - -def authenticate_request(c, user): - with c.session_transaction() as sess: - if user.id is None: - db.session.flush() - sess['user_id'] = user.id - - -@contextmanager -def authenticated_user(c, user=None): - if not user: - user = user_factory.create() - db.session.commit() - authenticate_request(c, user) - - yield user - - -def json_request(method, path, data=None): - if data: - response = method(path, data=json_dumps(data)) - else: - response = method(path) - - if response.data: - response.json = json.loads(response.data) - else: - response.json = None - - return response - - -def make_request(method, path, user, data=None, is_json=True): - with app.test_client() as c: - if user: - authenticate_request(c, user) - - method_fn = getattr(c, method.lower()) - headers = {} - - if data and is_json: - data = json_dumps(data) - - if is_json: - content_type = 'application/json' - else: - content_type = None - - response = method_fn(path, data=data, headers=headers, content_type=content_type) - - if response.data and is_json: - response.json = json.loads(response.data) - - return response - - -def get_request(path, org=None): - if org: - path = "/{}{}".format(org.slug, path) - - with app.test_client() as c: - return c.get(path) - - -def post_request(path, data=None, org=None): - if org: - path = "/{}{}".format(org.slug, path) - - with app.test_client() as c: - return c.post(path, data=data) From d59299b85a284564dba6a1805ca01e778e6c3813 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 17:53:06 +0200 Subject: [PATCH 15/80] Fix Alert model tests --- redash/models.py | 15 ++++++--------- tests/models/test_alerts.py | 8 ++++---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/redash/models.py b/redash/models.py index 7884ca831d..4cb2a26ab1 100644 --- a/redash/models.py +++ b/redash/models.py @@ -971,14 +971,13 @@ class Alert(TimestampMixin, db.Model): __tablename__ = 'alerts' @classmethod - def all(cls, groups): - return cls.select(Alert, User, Query)\ + def all(cls, group_ids): + # TODO: there was a join with user here to prevent N+1 queries. need to revisit this. + return db.session.query(Alert)\ .join(Query)\ - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source))\ - .where(DataSourceGroup.group << groups)\ - .switch(Alert)\ - .join(User)\ - .group_by(Alert, User, Query) + .join(DataSourceGroup, DataSourceGroup.data_source_id==Query.data_source_id)\ + .filter(DataSourceGroup.group_id.in_(group_ids))\ + .group_by(Alert) @classmethod def get_by_id_and_org(cls, id, org): @@ -1145,7 +1144,6 @@ def recent(cls, org, group_ids, user_id, for_user=False, limit=20): Event.object_id != None, Event.object_type == 'dashboard', Dashboard.org == org, - Dashboard.is_draft == False, Dashboard.is_archived == False, Dashboard.is_draft == False, DataSourceGroup.group_id.in_(group_ids) | @@ -1154,7 +1152,6 @@ def recent(cls, org, group_ids, user_id, for_user=False, limit=20): .group_by(Event.object_id, Dashboard.id) .order_by(db.desc(db.func.count(0)))) - if for_user: query = query.filter(Event.user_id == user_id) diff --git a/tests/models/test_alerts.py b/tests/models/test_alerts.py index 558ad706df..3d2c3bb34b 100644 --- a/tests/models/test_alerts.py +++ b/tests/models/test_alerts.py @@ -14,15 +14,15 @@ def test_returns_all_alerts_for_given_groups(self): alert1 = self.factory.create_alert(query=query1) alert2 = self.factory.create_alert(query=query2) - alerts = Alert.all(groups=[group, self.factory.default_group]) + alerts = Alert.all(group_ids=[group.id, self.factory.default_group.id]) self.assertIn(alert1, alerts) self.assertIn(alert2, alerts) - alerts = Alert.all(groups=[self.factory.default_group]) + alerts = Alert.all(group_ids=[self.factory.default_group.id]) self.assertIn(alert1, alerts) self.assertNotIn(alert2, alerts) - alerts = Alert.all(groups=[group]) + alerts = Alert.all(group_ids=[group.id]) self.assertNotIn(alert1, alerts) self.assertIn(alert2, alerts) @@ -32,6 +32,6 @@ def test_return_each_alert_only_once(self): alert = self.factory.create_alert() - alerts = Alert.all(groups=[self.factory.default_group, group]) + alerts = Alert.all(group_ids=[self.factory.default_group.id, group.id]) self.assertEqual(1, len(list(alerts))) self.assertIn(alert, alerts) From 8680ebe96f9f71747a52ea279762814cc117c8eb Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 17:53:19 +0200 Subject: [PATCH 16/80] Change Dashboard factory to generate non drafts --- tests/factories.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/factories.py b/tests/factories.py index eff5abf4ce..7931090cd3 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -57,7 +57,11 @@ def __call__(self): org_id=1) dashboard_factory = ModelFactory(redash.models.Dashboard, - name='test', user=user_factory.create, layout='[]', org=1) + name='test', + user=user_factory.create, + layout='[]', + is_draft=False, + org=1) api_key_factory = ModelFactory(redash.models.ApiKey, object=dashboard_factory.create) From c386ff91d686bc4ee24cffac10c724b6db1386b4 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 17:54:57 +0200 Subject: [PATCH 17/80] Fix data source models tests --- redash/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/redash/models.py b/redash/models.py index 4cb2a26ab1..b8fc5b6a00 100644 --- a/redash/models.py +++ b/redash/models.py @@ -402,8 +402,10 @@ def __unicode__(self): @classmethod def create_with_group(cls, *args, **kwargs): - data_source = cls.create(*args, **kwargs) - DataSourceGroup.create(data_source=data_source, group=data_source.org.default_group) + data_source = cls(*args, **kwargs) + data_source_group = DataSourceGroup(data_source=data_source, group=data_source.org.default_group) + + db.session.add_all([data_source, data_source_group]) return data_source def get_schema(self, refresh=False): From 811a4ef24898b074f2e3648bdd24b29915620d97 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 18:14:21 +0200 Subject: [PATCH 18/80] Fix accesspemrissions tests --- redash/models.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/redash/models.py b/redash/models.py index b8fc5b6a00..010ce4afa2 100644 --- a/redash/models.py +++ b/redash/models.py @@ -871,13 +871,26 @@ class AccessPermission(GFKBase, db.Model): @classmethod def grant(cls, obj, access_type, grantee, grantor): - return cls.get_or_create(object_type=obj._meta.db_table, object_id=obj.id, access_type=access_type, grantee=grantee, grantor=grantor)[0] + grant = cls.query.filter(cls.object_type==obj.__tablename__, + cls.object_id==obj.id, + cls.access_type==access_type, + cls.grantee==grantee, + cls.grantor==grantor).one_or_none() + + if not grant: + grant = cls(object_type=obj.__tablename__, + object_id=obj.id, + access_type=access_type, + grantee=grantee, + grantor=grantor) + db.session.add(grant) + + return grant @classmethod def revoke(cls, obj, grantee, access_type=None): - query = cls._query(cls.delete(), obj, access_type, grantee) - - return query.execute() + permissions = cls._query(obj, access_type, grantee) + return permissions.delete() @classmethod def find(cls, obj, access_type=None, grantee=None, grantor=None): @@ -888,18 +901,17 @@ def exists(cls, obj, access_type, grantee): return cls.find(obj, access_type, grantee).count() > 0 @classmethod - def _query(cls, base_query, obj, access_type=None, grantee=None, grantor=None): - q = base_query.where(cls.object_type == obj._meta.db_table) \ - .where(cls.object_id == obj.id) + def _query(cls, obj, access_type=None, grantee=None, grantor=None): + q = cls.query.filter(cls.object_id==obj.id, cls.object_type==obj.__tablename__) if access_type: - q = q.where(AccessPermission.access_type == access_type) + q.filter(AccessPermission.access_type == access_type) if grantee: - q = q.where(AccessPermission.grantee == grantee) + q.filter(AccessPermission.grantee_id == grantee.id) if grantor: - q = q.where(AccessPermission.grantor == grantor) + q.filter(AccessPermission.grantor_id == grantor.id) return q From d1fcb435625872576b4fd118900f647d5d17e250 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Tue, 29 Nov 2016 09:48:33 -0600 Subject: [PATCH 19/80] test_alerts passes --- redash/handlers/alerts.py | 43 ++++++++++++++++++++--------------- redash/handlers/base.py | 13 +++++++---- redash/models.py | 19 ++++++++-------- redash/permissions.py | 2 +- tests/factories.py | 4 ++-- tests/handlers/test_alerts.py | 33 +++++++++++++++------------ 6 files changed, 65 insertions(+), 49 deletions(-) diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index 4188889411..f704aa388e 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -21,9 +21,6 @@ def post(self, alert_id): alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) require_admin_or_owner(alert.user.id) - if 'query_id' in params: - params['query'] = params.pop('query_id') - alert.update_instance(**params) self.record_event({ @@ -33,12 +30,16 @@ def post(self, alert_id): 'object_type': 'alert' }) - return alert.to_dict() + d = alert.to_dict() + models.db.session.commit() + return d def delete(self, alert_id): - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, + self.current_org) require_admin_or_owner(alert.user.id) - alert.delete_instance(recursive=True) + models.db.session.delete(alert) + models.db.session.commit() class AlertListResource(BaseResource): @@ -46,16 +47,19 @@ def post(self): req = request.get_json(True) require_fields(req, ('options', 'name', 'query_id')) - query = models.Query.get_by_id_and_org(req['query_id'], self.current_org) + query = models.Query.get_by_id_and_org(req['query_id'], + self.current_org) require_access(query.groups, self.current_user, view_only) - alert = models.Alert.create( + alert = models.Alert( name=req['name'], - query=query, + query_rel=query, user=self.current_user, options=req['options'] ) + models.db.session.add(alert) + models.db.session.flush() self.record_event({ 'action': 'create', 'timestamp': int(time.time()), @@ -63,7 +67,9 @@ def post(self): 'object_type': 'alert' }) - return alert.to_dict() + a = alert.to_dict() + models.db.session.commit() + return a @require_permission('list_alerts') def get(self): @@ -82,8 +88,8 @@ def post(self, alert_id): destination = models.NotificationDestination.get_by_id_and_org(req['destination_id'], self.current_org) kwargs['destination'] = destination - subscription = models.AlertSubscription.create(**kwargs) - + subscription = models.AlertSubscription(**kwargs) + models.db.session.add(subscription) self.record_event({ 'action': 'subscribe', 'timestamp': int(time.time()), @@ -92,9 +98,12 @@ def post(self, alert_id): 'destination': req.get('destination_id') }) - return subscription.to_dict() + d = subscription.to_dict() + models.db.session.commit() + return d def get(self, alert_id): + alert_id = int(alert_id) alert = models.Alert.get_by_id_and_org(alert_id, self.current_org) require_access(alert.groups, self.current_user, view_only) @@ -104,15 +113,13 @@ def get(self, alert_id): class AlertSubscriptionResource(BaseResource): def delete(self, alert_id, subscriber_id): - - subscription = get_object_or_404(models.AlertSubscription.get_by_id, subscriber_id) + subscription = models.AlertSubscription.query.get_or_404(subscriber_id) require_admin_or_owner(subscription.user.id) - subscription.delete_instance() - + models.db.session.delete(subscription) self.record_event({ 'action': 'unsubscribe', 'timestamp': int(time.time()), 'object_id': alert_id, 'object_type': 'alert' }) - + models.db.session.commit() diff --git a/redash/handlers/base.py b/redash/handlers/base.py index fc537d6639..6c24f3bab8 100644 --- a/redash/handlers/base.py +++ b/redash/handlers/base.py @@ -3,7 +3,8 @@ from flask import Blueprint, current_app, request from flask_login import current_user, login_required from flask_restful import Resource, abort -from sqlalchemy.exc import DataError + +from sqlalchemy.orm.exc import NoResultFound from redash import settings from redash.authentication import current_org @@ -73,11 +74,13 @@ def require_fields(req, fields): def get_object_or_404(fn, *args, **kwargs): - rv = fn(*args, **kwargs) - if rv is None: + try: + rv = fn(*args, **kwargs) + if rv is None: + abort(404) + except NoResultFound: abort(404) - else: - return rv + return rv def paginate(query_set, page, page_size, serializer): diff --git a/redash/models.py b/redash/models.py index 010ce4afa2..5d7798b7dc 100644 --- a/redash/models.py +++ b/redash/models.py @@ -621,6 +621,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) schedule = Column(db.String(10), nullable=True) + visualizations = db.relationship("Visualization", cascade="all, delete-orphan") options = Column(PseudoJSON, default={}) __tablename__ = 'queries' @@ -973,12 +974,12 @@ class Alert(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) name = Column(db.String(255)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) - query = db.relationship(Query, backref='alerts') + query_rel = db.relationship(Query, backref='alerts', cascade="all") user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, backref='alerts') options = Column(PseudoJSON) state = Column(db.String(255), default=UNKNOWN_STATE) - subscriptions = db.relationship("AlertSubscription", cascade="delete") + subscriptions = db.relationship("AlertSubscription", cascade="all, delete-orphan") last_triggered_at = Column(db.DateTime(True), nullable=True) rearm = Column(db.Integer, nullable=True) @@ -995,7 +996,7 @@ def all(cls, group_ids): @classmethod def get_by_id_and_org(cls, id, org): - return cls.select(Alert, User, Query).join(Query).switch(Alert).join(User).where(cls.id==id, Query.org==org).get() + return db.session.query(Alert).join(Query).filter(Alert.id==id, Query.org==org).one() def to_dict(self, full=True): d = { @@ -1010,7 +1011,7 @@ def to_dict(self, full=True): } if full: - d['query'] = self.query.to_dict() + d['query'] = self.query_rel.to_dict() d['user'] = self.user.to_dict() else: d['query_id'] = self.query_id @@ -1019,7 +1020,7 @@ def to_dict(self, full=True): return d def evaluate(self): - data = json.loads(self.query.latest_query_data.data) + data = json.loads(self.query_rel.latest_query_data.data) # todo: safe guard for empty value = data['rows'][0][self.options['column']] op = self.options['op'] @@ -1036,11 +1037,11 @@ def evaluate(self): return new_state def subscribers(self): - return User.select().join(AlertSubscription).where(AlertSubscription.alert==self) + return User.query.join(AlertSubscription).filter(AlertSubscription.alert == self) @property def groups(self): - return self.query.groups + return self.query_rel.groups def generate_slug(ctx): @@ -1193,7 +1194,7 @@ class Visualization(TimestampMixin, db.Model): type = Column(db.String(100)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) # query_rel and not query, because db.Model already has query defined. - query_rel = db.relationship(Query, backref='visualizations') + query_rel = db.relationship(Query, back_populates='visualizations') name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) options = Column(db.Text) @@ -1416,7 +1417,7 @@ def to_dict(self): @classmethod def all(cls, alert_id): - return AlertSubscription.select(AlertSubscription, User).join(User).where(AlertSubscription.alert==alert_id) + return AlertSubscription.query.join(User).filter(AlertSubscription.alert_id == alert_id) def notify(self, alert, query, user, new_state, app, host): if self.destination: diff --git a/redash/permissions.py b/redash/permissions.py index 8f29cc77b2..d5dd7fa849 100644 --- a/redash/permissions.py +++ b/redash/permissions.py @@ -17,7 +17,7 @@ def has_access(object_groups, user, need_view_only): if 'admin' in user.permissions: return True - matching_groups = set(object_groups.keys()).intersection(user.groups) + matching_groups = set(object_groups.keys()).intersection(user.group_ids) if not matching_groups: return False diff --git a/tests/factories.py b/tests/factories.py index 7931090cd3..29de2fb58b 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -97,7 +97,7 @@ def __call__(self): alert_factory = ModelFactory(redash.models.Alert, name=Sequence('Alert {}'), - query=query_factory.create, + query_rel=query_factory.create, user=user_factory.create, options={}) @@ -224,7 +224,7 @@ def create_group_hack(self, **kwargs): def create_alert(self, **kwargs): args = { 'user': self.user, - 'query': self.create_query() + 'query_rel': self.create_query() } args.update(**kwargs) diff --git a/tests/handlers/test_alerts.py b/tests/handlers/test_alerts.py index 0053fbc10c..0bee3e3d25 100644 --- a/tests/handlers/test_alerts.py +++ b/tests/handlers/test_alerts.py @@ -1,5 +1,5 @@ from tests import BaseTestCase -from redash.models import AlertSubscription, Alert +from redash.models import AlertSubscription, Alert, db class TestAlertResourceGet(BaseTestCase): @@ -12,8 +12,8 @@ def test_returns_200_if_allowed(self): def test_returns_403_if_not_allowed(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query=query) - + alert = self.factory.create_alert(query_rel=query) + db.session.commit() rv = self.make_request('get', "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 403) @@ -31,12 +31,12 @@ class TestAlertResourceDelete(BaseTestCase): def test_removes_alert_and_subscriptions(self): subscription = self.factory.create_alert_subscription() alert = subscription.alert - + db.session.commit() rv = self.make_request('delete', "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 200) - self.assertRaises(Alert.DoesNotExist, Alert.get_by_id, subscription.alert.id) - self.assertRaises(AlertSubscription.DoesNotExist, AlertSubscription.get_by_id, subscription.id) + self.assertEqual(Alert.query.get(subscription.alert.id), None) + self.assertEqual(AlertSubscription.query.get(subscription.id), None) def test_returns_403_if_not_allowed(self): alert = self.factory.create_alert() @@ -61,7 +61,7 @@ class TestAlertListPost(BaseTestCase): def test_returns_200_if_has_access_to_query(self): query = self.factory.create_query() destination = self.factory.create_destination() - + db.session.commit() rv = self.make_request('post', "/api/alerts", data=dict(name='Alert', query_id=query.id, destination_id=destination.id, options={})) self.assertEqual(rv.status_code, 200) @@ -70,7 +70,7 @@ def test_fails_if_doesnt_have_access_to_query(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) destination = self.factory.create_destination() - + db.session.commit() rv = self.make_request('post', "/api/alerts", data=dict(name='Alert', query_id=query.id, destination_id=destination.id, options={})) self.assertEqual(rv.status_code, 403) @@ -88,7 +88,7 @@ def test_subscribers_user_to_alert(self): def test_doesnt_subscribers_user_to_alert_without_access(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query=query) + alert = self.factory.create_alert(query_rel=query) destination = self.factory.create_destination() rv = self.make_request('post', "/api/alerts/{}/subscriptions".format(alert.id), data=dict(destination_id=destination.id)) @@ -106,7 +106,7 @@ def test_returns_subscribers(self): def test_doesnt_return_subscribers_when_not_allowed(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) - alert = self.factory.create_alert(query=query) + alert = self.factory.create_alert(query_rel=query) rv = self.make_request('get', "/api/alerts/{}/subscriptions".format(alert.id)) self.assertEqual(rv.status_code, 403) @@ -117,7 +117,8 @@ def test_only_subscriber_or_admin_can_unsubscribe(self): subscription = self.factory.create_alert_subscription() alert = subscription.alert user = subscription.user - path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, subscription.id) + path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, + subscription.id) other_user = self.factory.create_user() @@ -127,7 +128,11 @@ def test_only_subscriber_or_admin_can_unsubscribe(self): response = self.make_request('delete', path, user=user) self.assertEqual(response.status_code, 200) - subscription_two = AlertSubscription.create(alert=alert, user=other_user) - path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, subscription_two.id) - response = self.make_request('delete', path, user=self.factory.create_admin()) + subscription_two = AlertSubscription(alert=alert, user=other_user) + admin_user = self.factory.create_admin() + db.session.add_all([subscription_two, admin_user]) + db.session.commit() + path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, + subscription_two.id) + response = self.make_request('delete', path, user=admin_user) self.assertEqual(response.status_code, 200) From bb755b5c259e862e54679a99a4641fd730d3fa2f Mon Sep 17 00:00:00 2001 From: Allen Short Date: Tue, 29 Nov 2016 09:48:39 -0600 Subject: [PATCH 20/80] test_models fixes --- redash/models.py | 7 ++-- tests/factories.py | 2 +- tests/test_models.py | 83 ++++++++++++++++++++++++-------------------- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/redash/models.py b/redash/models.py index 5d7798b7dc..693b295a7a 100644 --- a/redash/models.py +++ b/redash/models.py @@ -557,7 +557,6 @@ def store_result(cls, org, data_source, query_hash, query, data, run_time, retri data=data) db.session.add(query_result) logging.info("Inserted query (%s) data; id=%s", query_hash, query_result.id) - # TODO: Investigate how big an impact this select-before-update makes. queries = db.session.query(Query).filter( Query.query_hash == query_hash, @@ -736,17 +735,17 @@ def search(cls, term, groups): where &= Query.is_archived == False where &= DataSourceGroup.group_id.in_([g.id for g in groups]) query_ids = ( - cls.query(Query.id).join( + db.session.query(Query.id).join( DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .filter(where)).distinct() - return db.session.query(Query).join(User, Query.user_id == User.id).filter( + return Query.query.join(User, Query.user_id == User.id).filter( Query.id.in_(query_ids)) @classmethod def recent(cls, groups, user_id=None, limit=20): - query = (cls.query(Query).join(User, Query.user_id == User.id) + query = (cls.query.join(User, Query.user_id == User.id) .filter(Event.created_at > (db.func.current_date() - 7)) .join(Event, Query.id == Event.object_id.cast(db.Integer)) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) diff --git a/tests/factories.py b/tests/factories.py index 29de2fb58b..c5b4645a31 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -105,7 +105,7 @@ def __call__(self): data='{"columns":{}, "rows":[]}', runtime=1, retrieved_at=utcnow, - query_text="SELECT 1", + query="SELECT 1", query_hash=gen_query_hash('SELECT 1'), data_source=data_source_factory.create, org_id=1) diff --git a/tests/test_models.py b/tests/test_models.py index ab8eea5c7e..8f58646c35 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,8 +31,7 @@ def test_changing_query_text_changes_hash(self): q = self.factory.create_query() old_hash = q.query_hash - q.query = "SELECT 2;" - #q = db.session.query(models.Query).get(q.id) + q.query_text = "SELECT 2;" db.session.flush() self.assertNotEquals(old_hash, q.query_hash) @@ -136,13 +135,14 @@ def test_recent_excludes_drafts(self): q1 = self.factory.create_query() q2 = self.factory.create_query(is_draft=True) - models.Event.create(org=self.factory.org, user=self.factory.user, - action="edit", object_type="query", - object_id=q1.id) - models.Event.create(org=self.factory.org, user=self.factory.user, - action="edit", object_type="query", - object_id=q2.id) - + models.db.session.add_all([ + models.Event(org=self.factory.org, user=self.factory.user, + action="edit", object_type="query", + object_id=q1.id), + models.Event(org=self.factory.org, user=self.factory.user, + action="edit", object_type="query", + object_id=q2.id) + ]) recent = models.Query.recent([self.factory.default_group]) self.assertIn(q1, recent) @@ -222,7 +222,7 @@ def test_outdated_queries_skips_unscheduled_queries(self): def test_outdated_queries_works_with_ttl_based_schedule(self): two_hours_ago = utcnow() - datetime.timedelta(hours=2) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query.query, retrieved_at=two_hours_ago) + query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=two_hours_ago) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -231,7 +231,7 @@ def test_outdated_queries_works_with_ttl_based_schedule(self): def test_skips_fresh_queries(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago) + query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=half_an_hour_ago) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -240,7 +240,7 @@ def test_skips_fresh_queries(self): def test_outdated_queries_works_with_specific_time_schedule(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule=half_an_hour_ago.strftime('%H:%M')) - query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) + query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -262,7 +262,7 @@ def test_archived_query_doesnt_return_in_all(self): query = self.factory.create_query(schedule="1") yesterday = utcnow() - datetime.timedelta(days=1) query_result, _ = models.QueryResult.store_result( - query.org, query.data_source, query.query_hash, query.query, + query.org, query.data_source, query.query_hash, query.query_text, "1", 123, yesterday) query.latest_query_data = query_result @@ -277,7 +277,7 @@ def test_archived_query_doesnt_return_in_all(self): def test_removes_associated_widgets_from_dashboards(self): widget = self.factory.create_widget() - query = widget.visualization.query + query = widget.visualization.query_rel db.session.commit() query.archive() db.session.flush() @@ -292,7 +292,7 @@ def test_removes_scheduling(self): def test_deletes_alerts(self): subscription = self.factory.create_alert_subscription() - query = subscription.alert.query + query = subscription.alert.query_rel db.session.commit() query.archive() db.session.flush() @@ -497,9 +497,9 @@ def test_stores_the_result(self): self.assertEqual(query_result.data_source, self.data_source) def test_updates_existing_queries(self): - query1 = self.factory.create_query(query=self.query) - query2 = self.factory.create_query(query=self.query) - query3 = self.factory.create_query(query=self.query) + query1 = self.factory.create_query(query_text=self.query) + query2 = self.factory.create_query(query_text=self.query) + query3 = self.factory.create_query(query_text=self.query) query_result, _ = models.QueryResult.store_result( self.data_source.org, self.data_source, self.query_hash, @@ -510,22 +510,22 @@ def test_updates_existing_queries(self): self.assertEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_hash(self): - query1 = self.factory.create_query(query=self.query) - query2 = self.factory.create_query(query=self.query) - query3 = self.factory.create_query(query=self.query + "123") + query1 = self.factory.create_query(query_text=self.query) + query2 = self.factory.create_query(query_text=self.query) + query3 = self.factory.create_query(query_text=self.query + "123") query_result, _ = models.QueryResult.store_result( self.data_source.org, self.data_source, self.query_hash, self.query, self.data, self.runtime, self.utcnow) - + self.assertEqual(query1.latest_query_data, query_result) self.assertEqual(query2.latest_query_data, query_result) self.assertNotEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_data_source(self): - query1 = self.factory.create_query(query=self.query) - query2 = self.factory.create_query(query=self.query) - query3 = self.factory.create_query(query=self.query, data_source=self.factory.create_data_source()) + query1 = self.factory.create_query(query_text=self.query) + query2 = self.factory.create_query(query_text=self.query) + query3 = self.factory.create_query(query_text=self.query, data_source=self.factory.create_data_source()) query_result, _ = models.QueryResult.store_result( self.data_source.org, self.data_source, self.query_hash, @@ -608,14 +608,16 @@ def _set_up_dashboard_test(d): ]) d.q1 = d.factory.create_query(data_source=d.ds1) d.q2 = d.factory.create_query(data_source=d.ds2) - d.v1 = d.factory.create_visualization(query=d.q1) - d.v2 = d.factory.create_visualization(query=d.q2) + d.v1 = d.factory.create_visualization(query_rel=d.q1) + d.v2 = d.factory.create_visualization(query_rel=d.q2) d.w1 = d.factory.create_widget(visualization=d.v1) d.w2 = d.factory.create_widget(visualization=d.v2) d.w3 = d.factory.create_widget(visualization=d.v2, dashboard=d.w2.dashboard) d.w4 = d.factory.create_widget(visualization=d.v2) d.w5 = d.factory.create_widget(visualization=d.v1, dashboard=d.w4.dashboard) - + d.w1.dashboard.is_draft = False + d.w2.dashboard.is_draft = False + d.w4.dashboard.is_draft = False class TestDashboardAll(BaseTestCase): def setUp(self): @@ -675,25 +677,28 @@ def setUp(self): _set_up_dashboard_test(self) def test_returns_recent_dashboards_basic(self): - db.session.flush() db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", object_type="dashboard", object_id=self.w1.dashboard.id)) + db.session.flush() self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u2.group_ids, None)) def test_recent_excludes_drafts(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w2.dashboard.id) - - self.w2.dashboard.update_instance(is_draft=True) - self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) + models.db.session.add_all([ + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id), + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w2.dashboard.id)]) + + self.w2.dashboard.is_draft = True + self.assertIn(self.w1.dashboard, models.Dashboard.recent( + self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(self.w2.dashboard, models.Dashboard.recent( + self.u1.org, self.u1.group_ids, None)) def test_returns_recent_dashboards_created_by_user(self): - d1 = self.factory.create_dashboard(user=self.u1) + d1 = self.factory.create_dashboard(user=self.u1, is_draft=False) db.session.flush() db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", object_type="dashboard", object_id=d1.id)) @@ -703,6 +708,7 @@ def test_returns_recent_dashboards_created_by_user(self): def test_returns_recent_dashboards_with_no_visualizations(self): w1 = self.factory.create_widget(visualization=None) + w1.dashboard.is_draft = False db.session.flush() db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", object_type="dashboard", object_id=w1.dashboard.id)) @@ -744,6 +750,7 @@ def test_returns_each_dashboard_once(self): def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) + w1.dashboard.is_draft = False db.session.flush() db.session.add(models.Event( org=self.factory.org, user=self.u1, action="view", From dff39a6849bee3f9be80ebb978b13f6bc504c712 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Tue, 29 Nov 2016 14:58:04 -0600 Subject: [PATCH 21/80] test_dashboards passes --- redash/handlers/dashboards.py | 32 ++++++++++++++++++++------- redash/models.py | 21 +++++++++--------- tests/handlers/test_authentication.py | 2 +- tests/handlers/test_dashboards.py | 17 +++++--------- 4 files changed, 41 insertions(+), 31 deletions(-) diff --git a/redash/handlers/dashboards.py b/redash/handlers/dashboards.py index 0fcb8eea17..ef2a35847a 100644 --- a/redash/handlers/dashboards.py +++ b/redash/handlers/dashboards.py @@ -3,9 +3,10 @@ from flask import request, url_for from flask_restful import abort from funcy import distinct, project, take +from sqlalchemy.orm.exc import StaleDataError + from redash import models, serializers from redash.handlers.base import BaseResource, get_object_or_404 -from redash.models import ConflictDetectedError from redash.permissions import (can_modify, require_admin_or_owner, require_object_modify_permission, require_permission) @@ -37,9 +38,10 @@ def post(self): user=self.current_user, is_draft=True, layout='[]') + models.db.session.add(dashboard) + models.db.session.commit() return dashboard.to_dict() - class DashboardResource(BaseResource): @require_permission('list_dashboards') def get(self, dashboard_slug=None): @@ -63,13 +65,23 @@ def post(self, dashboard_slug): require_object_modify_permission(dashboard, self.current_user) + updates = project(dashboard_properties, ('name', 'layout', 'version', 'is_draft')) + + # SQLAlchemy handles the case where a concurrent transaction beats us + # to the update. But we still have to make sure that we're not starting + # out behind. + if 'version' in updates and updates['version'] != dashboard.version: + abort(409) + updates['changed_by'] = self.current_user + self.update_model(dashboard, updates) + models.db.session.add(dashboard) try: - dashboard.update_instance(**updates) - except ConflictDetectedError: + models.db.session.commit() + except StaleDataError: abort(409) result = dashboard.to_dict(with_widgets=True, user=self.current_user) @@ -79,9 +91,11 @@ def post(self, dashboard_slug): def delete(self, dashboard_slug): dashboard = models.Dashboard.get_by_slug_and_org(dashboard_slug, self.current_org) dashboard.is_archived = True - dashboard.save(changed_by=self.current_user) - - return dashboard.to_dict(with_widgets=True, user=self.current_user) + dashboard.record_changes(changed_by=self.current_user) + models.db.session.add(dashboard) + d = dashboard.to_dict(with_widgets=True, user=self.current_user) + models.db.session.commit() + return d class PublicDashboardResource(BaseResource): @@ -100,6 +114,7 @@ def post(self, dashboard_id): dashboard = models.Dashboard.get_by_id_and_org(dashboard_id, self.current_org) require_admin_or_owner(dashboard.user_id) api_key = models.ApiKey.create_for_object(dashboard, self.current_user) + models.db.session.flush() public_url = url_for('redash.public_dashboard', token=api_key.api_key, org_slug=self.current_org.slug, _external=True) self.record_event({ @@ -117,10 +132,11 @@ def delete(self, dashboard_id): if api_key: api_key.active = False - api_key.save() + models.db.session.add(api_key) self.record_event({ 'action': 'deactivate_api_key', 'object_id': dashboard.id, 'object_type': 'dashboard', }) + models.db.session.commit() diff --git a/redash/models.py b/redash/models.py index 693b295a7a..ae04d877ff 100644 --- a/redash/models.py +++ b/redash/models.py @@ -894,7 +894,7 @@ def revoke(cls, obj, grantee, access_type=None): @classmethod def find(cls, obj, access_type=None, grantee=None, grantor=None): - return cls._query(cls.select(cls), obj, access_type, grantee, grantor) + return cls._query(obj, access_type, grantee, grantor) @classmethod def exists(cls, obj, access_type, grantee): @@ -902,7 +902,8 @@ def exists(cls, obj, access_type, grantee): @classmethod def _query(cls, obj, access_type=None, grantee=None, grantor=None): - q = cls.query.filter(cls.object_id==obj.id, cls.object_type==obj.__tablename__) + q = cls.query.filter(cls.object_id == obj.id, + cls.object_type == obj.__tablename__) if access_type: q.filter(AccessPermission.access_type == access_type) @@ -1076,18 +1077,14 @@ def to_dict(self, with_widgets=False, user=None): layout = json.loads(self.layout) if with_widgets: - widget_list = Widget.select(Widget, Visualization, Query, User)\ - .where(Widget.dashboard == self.id)\ - .join(Visualization, join_type=peewee.JOIN_LEFT_OUTER)\ - .join(Query, join_type=peewee.JOIN_LEFT_OUTER)\ - .join(User, join_type=peewee.JOIN_LEFT_OUTER) + widget_list = Widget.query.filter(Widget.dashboard == self) widgets = {} for w in widget_list: if w.visualization_id is None: widgets[w.id] = w.to_dict() - elif user and has_access(w.visualization.query.groups, user, view_only): + elif user and has_access(w.visualization.query_rel.groups, user, view_only): widgets[w.id] = w.to_dict() else: widgets[w.id] = project(w.to_dict(), @@ -1175,7 +1172,7 @@ def recent(cls, org, group_ids, user_id, for_user=False, limit=20): @classmethod def get_by_slug_and_org(cls, slug, org): - return cls.get(cls.slug == slug, cls.org==org) + return cls.query.filter(cls.slug == slug, cls.org==org).one() def tracked_save(self, changing_user, old_object=None, *args, **kwargs): self.version += 1 @@ -1330,11 +1327,13 @@ def get_by_api_key(cls, api_key): @classmethod def get_by_object(cls, object): - return cls.select().where(cls.object_type==object._meta.db_table, cls.object_id==object.id, cls.active==True).first() + return cls.query.filter(cls.object_type==object.__class__.__tablename__, cls.object_id==object.id, cls.active==True).first() @classmethod def create_for_object(cls, object, user): - return cls.create(org=user.org, object=object, created_by=user) + k = cls(org=user.org, object=object, created_by=user) + db.session.add(k) + return k class NotificationDestination(BelongsToOrgMixin, db.Model): diff --git a/tests/handlers/test_authentication.py b/tests/handlers/test_authentication.py index c7989e5d89..550616c04c 100644 --- a/tests/handlers/test_authentication.py +++ b/tests/handlers/test_authentication.py @@ -52,7 +52,7 @@ def test_valid_password(self): password = 'test1234' response = self.post_request('/invite/{}'.format(token), data={'password': password}, org=self.factory.org) self.assertEqual(response.status_code, 302) - user = User.get_by_id(self.factory.user.id) + user = User.query.get(self.factory.user.id) self.assertTrue(user.verify_password(password)) diff --git a/tests/handlers/test_dashboards.py b/tests/handlers/test_dashboards.py index e367094c77..f2602427fb 100644 --- a/tests/handlers/test_dashboards.py +++ b/tests/handlers/test_dashboards.py @@ -1,6 +1,6 @@ import json from tests import BaseTestCase -from redash.models import ApiKey, Dashboard, AccessPermission +from redash.models import ApiKey, Dashboard, AccessPermission, db from redash.permissions import ACCESS_TYPE_MODIFY @@ -30,15 +30,14 @@ def test_get_dashboard_filters_unauthorized_widgets(self): restricted_ds = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=restricted_ds) - vis = self.factory.create_visualization(query=query) + vis = self.factory.create_visualization(query_rel=query) restricted_widget = self.factory.create_widget(visualization=vis, dashboard=dashboard) widget = self.factory.create_widget(dashboard=dashboard) dashboard.layout = '[[{}, {}]]'.format(widget.id, restricted_widget.id) - dashboard.save() + db.session.commit() rv = self.make_request('get', '/api/dashboards/{0}'.format(dashboard.slug)) self.assertEquals(rv.status_code, 200) - self.assertTrue(rv.json['widgets'][0][1]['restricted']) self.assertNotIn('restricted', rv.json['widgets'][0][0]) @@ -59,8 +58,7 @@ def test_update_dashboard(self): def test_raises_error_in_case_of_conflict(self): d = self.factory.create_dashboard() d.name = 'Updated' - d.save() - + db.session.commit() new_name = 'New Name' rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), data={'name': new_name, 'layout': '[]', 'version': d.version - 1}) @@ -70,7 +68,6 @@ def test_raises_error_in_case_of_conflict(self): def test_overrides_existing_if_no_version_specified(self): d = self.factory.create_dashboard() d.name = 'Updated' - d.save() new_name = 'New Name' rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), @@ -122,8 +119,7 @@ def test_requires_admin_or_owner(self): res = self.make_request('post', '/api/dashboards/{}/share'.format(dashboard.id), user=user) self.assertEqual(res.status_code, 403) - user.groups.append(self.factory.org.admin_group.id) - user.save() + user.group_ids.append(self.factory.org.admin_group.id) res = self.make_request('post', '/api/dashboards/{}/share'.format(dashboard.id), user=user) self.assertEqual(res.status_code, 200) @@ -151,8 +147,7 @@ def test_requires_admin_or_owner(self): res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id), user=user) self.assertEqual(res.status_code, 403) - user.groups.append(self.factory.org.admin_group.id) - user.save() + user.group_ids.append(self.factory.org.admin_group.id) res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id), user=user) self.assertEqual(res.status_code, 200) From c6ef6041cf63a7de3ac99bc6d2f12d126833d431 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Tue, 29 Nov 2016 18:20:36 -0600 Subject: [PATCH 22/80] test_data_sources passes --- redash/handlers/data_sources.py | 18 +++++++++++------- redash/models.py | 11 ++++++----- redash/utils/configuration.py | 17 ++++++++++++++++- tests/handlers/test_data_sources.py | 16 +++++++--------- 4 files changed, 40 insertions(+), 22 deletions(-) diff --git a/redash/handlers/data_sources.py b/redash/handlers/data_sources.py index c27e8b8cdc..2ba19bc1b7 100644 --- a/redash/handlers/data_sources.py +++ b/redash/handlers/data_sources.py @@ -30,16 +30,16 @@ def post(self, data_source_id): schema = get_configuration_schema_for_query_runner_type(req['type']) if schema is None: abort(400) - try: data_source.options.set_schema(schema) data_source.options.update(req['options']) except ValidationError: abort(400) - + data_source.type = req['type'] data_source.name = req['name'] - data_source.save() + models.db.session.add(data_source) + models.db.session.commit() return data_source.to_dict(all=True) @@ -57,7 +57,9 @@ def get(self): if self.current_user.has_permission('admin'): data_sources = models.DataSource.all(self.current_org) else: - data_sources = models.DataSource.all(self.current_org, groups=self.current_user.groups) + data_sources = models.DataSource.all( + self.current_org, + group_ids=self.current_user.group_ids) response = {} for ds in data_sources: @@ -66,7 +68,7 @@ def get(self): try: d = ds.to_dict() - d['view_only'] = all(project(ds.groups, self.current_user.groups).values()) + d['view_only'] = all(project(ds.groups, self.current_user.group_ids).values()) response[ds.id] = d except AttributeError: logging.exception("Error with DataSource#to_dict (data source id: %d)", ds.id) @@ -123,7 +125,8 @@ def post(self, data_source_id): reason = request.args.get('reason') data_source.pause(reason) - data_source.save() + models.db.session.add(data_source) + models.db.session.commit() self.record_event({ 'action': 'pause', @@ -137,7 +140,8 @@ def post(self, data_source_id): def delete(self, data_source_id): data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org) data_source.resume() - data_source.save() + models.db.session.add(data_source) + models.db.session.commit() self.record_event({ 'action': 'resume', diff --git a/redash/models.py b/redash/models.py index ae04d877ff..6be48860af 100644 --- a/redash/models.py +++ b/redash/models.py @@ -363,7 +363,7 @@ class DataSource(BelongsToOrgMixin, db.Model): name = Column(db.String(255)) type = Column(db.String(255)) - options = Column(Configuration) + options = Column(ConfigurationContainer.as_mutable(Configuration)) queue_name = Column(db.String(255), default="queries") scheduled_queue_name = Column(db.String(255), default="scheduled_queries") created_at = Column(db.DateTime(True), default=db.func.now()) @@ -463,11 +463,12 @@ def query_runner(self): return get_query_runner(self.type, self.options) @classmethod - def all(cls, org, groups=None): - data_sources = cls.select().where(cls.org==org).order_by(cls.id.asc()) + def all(cls, org, group_ids=None): + data_sources = cls.query.filter(cls.org == org).order_by(cls.id.asc()) - if groups: - data_sources = data_sources.join(DataSourceGroup).where(DataSourceGroup.group << groups) + if group_ids: + data_sources = data_sources.join(DataSourceGroup).filter( + DataSourceGroup.group_id.in_(group_ids)) return data_sources diff --git a/redash/utils/configuration.py b/redash/utils/configuration.py index 3d8802b3fd..615cc162f8 100644 --- a/redash/utils/configuration.py +++ b/redash/utils/configuration.py @@ -2,10 +2,23 @@ import jsonschema from jsonschema import ValidationError +from sqlalchemy.ext.mutable import Mutable + SECRET_PLACEHOLDER = '--------' -class ConfigurationContainer(object): +class ConfigurationContainer(Mutable): + @classmethod + def coerce(cls, key, value): + if not isinstance(value, ConfigurationContainer): + if isinstance(value, dict): + return ConfigurationContainer(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + def __init__(self, config, schema=None): self._config = config self.set_schema(schema) @@ -59,12 +72,14 @@ def update(self, new_config): config[k] = v self._config = config + self.changed() def get(self, *args, **kwargs): return self._config.get(*args, **kwargs) def __setitem__(self, key, value): self._config[key] = value + self.changed() def __getitem__(self, item): if item in self._config: diff --git a/tests/handlers/test_data_sources.py b/tests/handlers/test_data_sources.py index 67147d6273..4e4fe35702 100644 --- a/tests/handlers/test_data_sources.py +++ b/tests/handlers/test_data_sources.py @@ -20,8 +20,7 @@ def test_fails_if_user_doesnt_belong_to_org(self): class TestDataSourceListGet(BaseTestCase): def test_returns_each_data_source_once(self): group = self.factory.create_group() - self.factory.user.groups.append(group.id) - self.factory.user.save() + self.factory.user.group_ids.append(group.id) self.factory.data_source.add_group(group) self.factory.data_source.add_group(self.factory.org.default_group) response = self.make_request("get", "/api/data_sources", user=self.factory.user) @@ -67,7 +66,7 @@ def test_updates_data_source(self): user=admin) self.assertEqual(rv.status_code, 200) - data_source = DataSource.get_by_id(self.factory.data_source.id) + data_source = DataSource.query.get(self.factory.data_source.id) self.assertEqual(data_source.name, new_name) self.assertEqual(data_source.options.to_dict(), new_options) @@ -103,17 +102,17 @@ def test_pauses_data_source(self): admin = self.factory.create_admin() rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(DataSource.get_by_id(self.factory.data_source.id).paused, True) + self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, True) def test_pause_sets_reason(self): admin = self.factory.create_admin() rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin, data={'reason': 'testing'}) self.assertEqual(rv.status_code, 200) - self.assertEqual(DataSource.get_by_id(self.factory.data_source.id).paused, True) - self.assertEqual(DataSource.get_by_id(self.factory.data_source.id).pause_reason, 'testing') + self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, True) + self.assertEqual(DataSource.query.get(self.factory.data_source.id).pause_reason, 'testing') rv = self.make_request('post', '/api/data_sources/{}/pause?reason=test'.format(self.factory.data_source.id), user=admin) - self.assertEqual(DataSource.get_by_id(self.factory.data_source.id).pause_reason, 'test') + self.assertEqual(DataSource.query.get(self.factory.data_source.id).pause_reason, 'test') def test_requires_admin(self): rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id)) @@ -124,10 +123,9 @@ class TestDataSourcePauseDelete(BaseTestCase): def test_resumes_data_source(self): admin = self.factory.create_admin() self.factory.data_source.pause() - self.factory.data_source.save() rv = self.make_request('delete', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(DataSource.get_by_id(self.factory.data_source.id).paused, False) + self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, False) def test_requires_admin(self): rv = self.make_request('delete', '/api/data_sources/{}/pause'.format(self.factory.data_source.id)) From c355eeffb606999e3f25c22ffd3fde196e03dbde Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 11:52:56 +0200 Subject: [PATCH 23/80] Fix test_permissions tests --- tests/test_permissions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 7ed70e5940..b2a65076c4 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -3,7 +3,7 @@ from redash.permissions import has_access -MockUser = namedtuple('MockUser', ['permissions', 'groups']) +MockUser = namedtuple('MockUser', ['permissions', 'group_ids']) view_only = True From 6b2d6a22f5d1f2b0b3c7186c8d909951f5674f65 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 11:54:00 +0200 Subject: [PATCH 24/80] Change groups property of ApiUser to be group_ids --- redash/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redash/models.py b/redash/models.py index 6be48860af..25aa0303a6 100644 --- a/redash/models.py +++ b/redash/models.py @@ -156,7 +156,7 @@ def __init__(self, api_key, org, groups, name=None): self.id = api_key.api_key self.name = "ApiKey: {}".format(api_key.id) self.object = api_key.object - self.groups = groups + self.group_ids = groups self.org = org def __repr__(self): From b61dbfa16b379fde68cb572c7dc966660f30f481 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 14:18:00 +0200 Subject: [PATCH 25/80] Fix test_authentication tests --- tests/test_authentication.py | 66 ++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 5275975384..41bff2fa57 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -23,40 +23,48 @@ def setUp(self): self.queries_url = '/{}/api/queries'.format(self.factory.org.slug) def test_no_api_key(self): - rv = self.client.get(self.query_url) - self.assertIsNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.query_url) + self.assertIsNone(api_key_load_user_from_request(request)) def test_wrong_api_key(self): - rv = self.client.get(self.query_url, query_string={'api_key': 'whatever'}) - self.assertIsNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.query_url, query_string={'api_key': 'whatever'}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_correct_api_key(self): - rv = self.client.get(self.query_url, query_string={'api_key': self.api_key}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.query_url, query_string={'api_key': self.api_key}) + self.assertIsNotNone(api_key_load_user_from_request(request)) def test_no_query_id(self): - rv = self.client.get(self.queries_url, query_string={'api_key': self.api_key}) - self.assertIsNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.queries_url, query_string={'api_key': self.api_key}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") models.db.session.flush() - rv = self.client.get(self.queries_url, query_string={'api_key': user.api_key}) - self.assertEqual(user.id, api_key_load_user_from_request(request).id) + with self.app.test_client() as c: + rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) + self.assertEqual(user.id, api_key_load_user_from_request(request).id) def test_api_key_header(self): - rv = self.client.get(self.query_url, headers={'Authorization': "Key {}".format(self.api_key)}) - self.assertIsNotNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(self.api_key)}) + self.assertIsNotNone(api_key_load_user_from_request(request)) def test_api_key_header_with_wrong_key(self): - rv = self.client.get(self.query_url, headers={'Authorization': "Key oops"}) - self.assertIsNone(api_key_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.query_url, headers={'Authorization': "Key oops"}) + self.assertIsNone(api_key_load_user_from_request(request)) def test_api_key_for_wrong_org(self): other_user = self.factory.create_admin(org=self.factory.create_org()) - rv = self.client.get(self.query_url, headers={'Authorization': "Key {}".format(other_user.api_key)}) - self.assertEqual(404, rv.status_code) + with self.app.test_client() as c: + rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(other_user.api_key)}) + self.assertEqual(404, rv.status_code) class TestHMACAuthentication(BaseTestCase): @@ -75,20 +83,24 @@ def signature(self, expires): return sign(self.query.api_key, self.path, expires) def test_no_signature(self): - rv = self.client.get(self.path) - self.assertIsNone(hmac_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.path) + self.assertIsNone(hmac_load_user_from_request(request)) def test_wrong_signature(self): - rv = self.client.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) - self.assertIsNone(hmac_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) + self.assertIsNone(hmac_load_user_from_request(request)) def test_correct_signature(self): - rv = self.client.get(self.path, query_string={'signature': self.signature(self.expires), 'expires': self.expires}) - self.assertIsNotNone(hmac_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get(self.path, query_string={'signature': self.signature(self.expires), 'expires': self.expires}) + self.assertIsNotNone(hmac_load_user_from_request(request)) def test_no_query_id(self): - rv = self.client.get('/{}/api/queries'.format(self.query.org.slug), query_string={'api_key': self.api_key}) - self.assertIsNone(hmac_load_user_from_request(request)) + with self.app.test_client() as c: + rv = c.get('/{}/api/queries'.format(self.query.org.slug), query_string={'api_key': self.api_key}) + self.assertIsNone(hmac_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") @@ -96,8 +108,9 @@ def test_user_api_key(self): models.db.session.flush() signature = sign(user.api_key, path, self.expires) - rv = self.client.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) - self.assertEqual(user.id, hmac_load_user_from_request(request).id) + with self.app.test_client() as c: + rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) + self.assertEqual(user.id, hmac_load_user_from_request(request).id) class TestCreateAndLoginUser(BaseTestCase): @@ -119,6 +132,7 @@ def test_creates_vaild_new_user(self): user = models.User.query.filter(models.User.email == email).one() self.assertEqual(user.email, email) + class TestVerifyProfile(BaseTestCase): def test_no_domain_allowed_for_org(self): profile = dict(email='arik@example.com') From 419234a23ebdbe6c2a83ab218d94ed0bc6c73a64 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 15:38:53 +0200 Subject: [PATCH 26/80] Fix widget tests --- redash/handlers/widgets.py | 13 ++++++++----- redash/models.py | 4 ++-- tests/handlers/test_widgets.py | 11 ++++++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/redash/handlers/widgets.py b/redash/handlers/widgets.py index 530e217dc8..9aab8023c1 100644 --- a/redash/handlers/widgets.py +++ b/redash/handlers/widgets.py @@ -22,13 +22,15 @@ def post(self): visualization_id = widget_properties.pop('visualization_id') if visualization_id: visualization = models.Visualization.get_by_id_and_org(visualization_id, self.current_org) - require_access(visualization.query.groups, self.current_user, view_only) + require_access(visualization.query_rel.groups, self.current_user, view_only) else: visualization = None widget_properties['visualization'] = visualization - widget = models.Widget.create(**widget_properties) + widget = models.Widget(**widget_properties) + models.db.session.add(widget) + models.db.session.commit() layout = json.loads(widget.dashboard.layout) new_row = True @@ -36,7 +38,7 @@ def post(self): if len(layout) == 0 or widget.width == 2: layout.append([widget.id]) elif len(layout[-1]) == 1: - neighbour_widget = models.Widget.get(models.Widget.id == layout[-1][0]) + neighbour_widget = models.Widget.query.get(layout[-1][0]) if neighbour_widget.width == 1: layout[-1].append(widget.id) new_row = False @@ -46,7 +48,7 @@ def post(self): layout.append([widget.id]) widget.dashboard.layout = json.dumps(layout) - widget.dashboard.save() + models.db.session.add(widget.dashboard) return {'widget': widget.to_dict(), 'layout': layout, 'new_row': new_row, 'version': dashboard.version} @@ -67,6 +69,7 @@ def post(self, widget_id): def delete(self, widget_id): widget = models.Widget.get_by_id_and_org(widget_id, self.current_org) require_object_modify_permission(widget.dashboard, self.current_user) - widget.delete_instance() + + models.db.session.delete(widget) return {'layout': widget.dashboard.layout, 'version': widget.dashboard.version} diff --git a/redash/models.py b/redash/models.py index 25aa0303a6..ed006fe33a 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1068,6 +1068,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) + widgets = db.relationship('Widget', backref='dashboard', lazy='dynamic') __tablename__ = 'dashboards' __mapper_args__ = { @@ -1232,7 +1233,6 @@ class Widget(TimestampMixin, db.Model): width = Column(db.Integer) options = Column(db.Text) dashboard_id = Column(db.Integer, db.ForeignKey("dashboards.id"), index=True) - dashboard = db.relationship(Dashboard) # unused; kept for backward compatability: type = Column(db.String(100), nullable=True) @@ -1261,7 +1261,7 @@ def __unicode__(self): @classmethod def get_by_id_and_org(cls, widget_id, org): - return cls.select(cls, Dashboard).join(Dashboard).where(cls.id == widget_id, Dashboard.org == org).get() + return db.session.query(cls).join(Dashboard).filter(cls.id == widget_id, Dashboard.org== org).one() #XXX produces SQLA warning, replace with association table @listens_for(Widget, 'before_delete') diff --git a/tests/handlers/test_widgets.py b/tests/handlers/test_widgets.py index 36923faf4e..d2ce8f6471 100644 --- a/tests/handlers/test_widgets.py +++ b/tests/handlers/test_widgets.py @@ -22,15 +22,15 @@ def test_create_widget(self): rv = self.create_widget(dashboard, vis) self.assertEquals(rv.status_code, 200) - dashboard = models.Dashboard.get(models.Dashboard.id == dashboard.id) + dashboard = models.Dashboard.query.get(dashboard.id) self.assertEquals(unicode(rv.json['layout']), dashboard.layout) - self.assertEquals(dashboard.widgets, 1) + self.assertEquals(dashboard.widgets.count(), 1) self.assertEquals(rv.json['layout'], [[rv.json['widget']['id']]]) self.assertEquals(rv.json['new_row'], True) rv2 = self.create_widget(dashboard, vis) - self.assertEquals(dashboard.widgets, 2) + self.assertEquals(dashboard.widgets.count(), 2) self.assertEquals(rv2.json['layout'], [[rv.json['widget']['id'], rv2.json['widget']['id']]]) self.assertEquals(rv2.json['new_row'], False) @@ -48,8 +48,9 @@ def test_wont_create_widget_for_visualization_you_dont_have_access_to(self): dashboard = self.factory.create_dashboard() vis = self.factory.create_visualization() ds = self.factory.create_data_source(group=self.factory.create_group()) - vis.query.data_source = ds - vis.query.save() + vis.query_rel.data_source = ds + + models.db.session.add(vis.query_rel) data = { 'visualization_id': vis.id, From f4c76527ee1535dac9c2d2c225566fd77eae651b Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 16:02:27 +0200 Subject: [PATCH 27/80] Fix refresh queries tests --- redash/models.py | 6 +-- redash/tasks/base.py | 11 ++++-- redash/tasks/queries.py | 2 +- tests/factories.py | 2 +- tests/tasks/test_refresh_queries.py | 57 ++++++++++++++--------------- 5 files changed, 40 insertions(+), 38 deletions(-) diff --git a/redash/models.py b/redash/models.py index ed006fe33a..82e6fb692b 100644 --- a/redash/models.py +++ b/redash/models.py @@ -499,7 +499,7 @@ class QueryResult(db.Model, BelongsToOrgMixin): data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) data_source = db.relationship(DataSource) query_hash = Column(db.String(32), index=True) - query = Column(db.Text) + query_text = Column(db.Text) data = Column(db.Text) runtime = Column(postgresql.DOUBLE_PRECISION) retrieved_at = Column(db.DateTime(True)) @@ -510,7 +510,7 @@ def to_dict(self): return { 'id': self.id, 'query_hash': self.query_hash, - 'query': self.query, + 'query': self.query_text, 'data': json.loads(self.data), 'data_source_id': self.data_source_id, 'runtime': self.runtime, @@ -551,7 +551,7 @@ def get_latest(cls, data_source, query, max_age=0): def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): query_result = cls(org=org, query_hash=query_hash, - query=query, + query_text=query, runtime=run_time, data_source=data_source, retrieved_at=retrieved_at, diff --git a/redash/tasks/base.py b/redash/tasks/base.py index c9dcbce9e5..81ac989dda 100644 --- a/redash/tasks/base.py +++ b/redash/tasks/base.py @@ -1,13 +1,18 @@ from celery import Task -from redash import models +from redash import create_app +from flask import has_app_context, current_app class BaseTask(Task): abstract = True def after_return(self, *args, **kwargs): - models.db.close_db(None) + if hasattr(self, 'app_ctx'): + self.app_ctx.pop() def __call__(self, *args, **kwargs): - models.db.connect_db() + if not has_app_context(): + flask_app = current_app or create_app() + self.app_ctx = flask_app.app_context() + self.app_ctx.push() return super(BaseTask, self).__call__(*args, **kwargs) diff --git a/redash/tasks/queries.py b/redash/tasks/queries.py index afaebe10d6..83b7267c96 100644 --- a/redash/tasks/queries.py +++ b/redash/tasks/queries.py @@ -262,7 +262,7 @@ def refresh_queries(): elif query.data_source.paused: logging.info("Skipping refresh of %s because datasource - %s is paused (%s).", query.id, query.data_source.name, query.data_source.pause_reason) else: - enqueue_query(query.query, query.data_source, query.user_id, + enqueue_query(query.query_text, query.data_source, query.user_id, scheduled=True, metadata={'Query ID': query.id, 'Username': 'Scheduled'}) diff --git a/tests/factories.py b/tests/factories.py index c5b4645a31..29de2fb58b 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -105,7 +105,7 @@ def __call__(self): data='{"columns":{}, "rows":[]}', runtime=1, retrieved_at=utcnow, - query="SELECT 1", + query_text="SELECT 1", query_hash=gen_query_hash('SELECT 1'), data_source=data_source_factory.create, org_id=1) diff --git a/tests/tasks/test_refresh_queries.py b/tests/tasks/test_refresh_queries.py index 8927f9c928..03dbfdd888 100644 --- a/tests/tasks/test_refresh_queries.py +++ b/tests/tasks/test_refresh_queries.py @@ -3,6 +3,7 @@ from tests import BaseTestCase from redash.utils import utcnow from redash.tasks import refresh_queries +from redash.models import db # TODO: this test should be split into two: @@ -12,22 +13,22 @@ class TestRefreshQueries(BaseTestCase): def test_enqueues_outdated_queries(self): query = self.factory.create_query(schedule="60") retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, query_hash=query.query_hash) query.latest_query_data = query_result - query.save() + db.session.add(query) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() - add_job_mock.assert_called_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY) + add_job_mock.assert_called_with(query.query_text, query.data_source, query.user_id, scheduled=True, metadata=ANY) def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self): query = self.factory.create_query(schedule="60") retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, query_hash=query.query_hash) query.latest_query_data = query_result - query.save() + db.session.add(query) query.data_source.pause() @@ -39,13 +40,13 @@ def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self): with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() - add_job_mock.assert_called_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY) + add_job_mock.assert_called_with(query.query_text, query.data_source, query.user_id, scheduled=True, metadata=ANY) def test_skips_fresh_queries(self): query = self.factory.create_query(schedule="1200") retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, - query_hash=query.query_hash) + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, + query_hash=query.query_hash) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() @@ -54,8 +55,8 @@ def test_skips_fresh_queries(self): def test_skips_queries_with_no_ttl(self): query = self.factory.create_query(schedule=None) retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, - query_hash=query.query_hash) + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, + query_hash=query.query_hash) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() @@ -63,49 +64,45 @@ def test_skips_queries_with_no_ttl(self): def test_enqueues_query_only_once(self): query = self.factory.create_query(schedule="60") - query2 = self.factory.create_query(schedule="60", query=query.query, query_hash=query.query_hash) + query2 = self.factory.create_query(schedule="60", query_text=query.query_text, query_hash=query.query_hash) retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, - query_hash=query.query_hash) + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, + query_hash=query.query_hash) query.latest_query_data = query_result query2.latest_query_data = query_result - query.save() - query2.save() + db.session.add_all([query, query2]) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() - add_job_mock.assert_called_once_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)#{'Query ID': query.id, 'Username': 'Scheduled'}) + add_job_mock.assert_called_once_with(query.query_text, query.data_source, query.user_id, scheduled=True, metadata=ANY)#{'Query ID': query.id, 'Username': 'Scheduled'}) def test_enqueues_query_with_correct_data_source(self): query = self.factory.create_query(schedule="60", data_source=self.factory.create_data_source()) - query2 = self.factory.create_query(schedule="60", query=query.query, query_hash=query.query_hash) + query2 = self.factory.create_query(schedule="60", query_text=query.query_text, query_hash=query.query_hash) retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, - query_hash=query.query_hash) + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, + query_hash=query.query_hash) query.latest_query_data = query_result query2.latest_query_data = query_result - query.save() - query2.save() + db.session.add_all([query, query2]) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() - add_job_mock.assert_has_calls([call(query2.query, query2.data_source, query2.user_id, scheduled=True, metadata=ANY), - call(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)], + add_job_mock.assert_has_calls([call(query2.query_text, query2.data_source, query2.user_id, scheduled=True, metadata=ANY), + call(query.query_text, query.data_source, query.user_id, scheduled=True, metadata=ANY)], any_order=True) self.assertEquals(2, add_job_mock.call_count) def test_enqueues_only_for_relevant_data_source(self): query = self.factory.create_query(schedule="60") - query2 = self.factory.create_query(schedule="3600", query=query.query, query_hash=query.query_hash) - import psycopg2 + query2 = self.factory.create_query(schedule="3600", query_text=query.query_text, query_hash=query.query_hash) retrieved_at = utcnow() - datetime.timedelta(minutes=10) - query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, - query_hash=query.query_hash) + query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query_text=query.query_text, + query_hash=query.query_hash) query.latest_query_data = query_result query2.latest_query_data = query_result - query.save() - query2.save() + db.session.add_all([query, query2]) with patch('redash.tasks.queries.enqueue_query') as add_job_mock: refresh_queries() - add_job_mock.assert_called_once_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY) + add_job_mock.assert_called_once_with(query.query_text, query.data_source, query.user_id, scheduled=True, metadata=ANY) From 4459c464ca1ee391bc2f5f3c2e001b9c736ad183 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 16:12:53 +0200 Subject: [PATCH 28/80] use Query#query_text instead of Query#query --- tests/tasks/test_queries.py | 12 ++++++------ tests/test_models.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/tasks/test_queries.py b/tests/tasks/test_queries.py index 1df8e1c850..198c2ea086 100644 --- a/tests/tasks/test_queries.py +++ b/tests/tasks/test_queries.py @@ -47,9 +47,9 @@ def test_multiple_enqueue_of_same_query(self): query = self.factory.create_query() execute_query.apply_async = MagicMock(side_effect=gen_hash) - enqueue_query(query.query, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) self.assertEqual(1, execute_query.apply_async.call_count) self.assertEqual(1, redis_connection.zcard(QueryTaskTracker.WAITING_LIST)) @@ -60,9 +60,9 @@ def test_multiple_enqueue_of_different_query(self): query = self.factory.create_query() execute_query.apply_async = MagicMock(side_effect=gen_hash) - enqueue_query(query.query, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query + '2', query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query + '3', query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text, query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text + '2', query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query(query.query_text + '3', query.data_source, True, {'Username': 'Arik', 'Query ID': query.id}) self.assertEqual(3, execute_query.apply_async.call_count) self.assertEqual(3, redis_connection.zcard(QueryTaskTracker.WAITING_LIST)) diff --git a/tests/test_models.py b/tests/test_models.py index 8f58646c35..846773523b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -346,14 +346,14 @@ def test_get_latest_returns_none_if_not_found(self): def test_get_latest_returns_when_found(self): qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, 60) + found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, 60) self.assertEqual(qr, found_query_result) def test_get_latest_doesnt_return_query_from_different_data_source(self): qr = self.factory.create_query_result() data_source = self.factory.create_data_source() - found_query_result = models.QueryResult.get_latest(data_source, qr.query, 60) + found_query_result = models.QueryResult.get_latest(data_source, qr.query_text, 60) self.assertIsNone(found_query_result) @@ -361,7 +361,7 @@ def test_get_latest_doesnt_return_if_ttl_expired(self): yesterday = utcnow() - datetime.timedelta(days=1) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=60) + found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, max_age=60) self.assertIsNone(found_query_result) @@ -369,7 +369,7 @@ def test_get_latest_returns_if_ttl_not_expired(self): yesterday = utcnow() - datetime.timedelta(seconds=30) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=120) + found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, max_age=120) self.assertEqual(found_query_result, qr) @@ -378,7 +378,7 @@ def test_get_latest_returns_the_most_recent_result(self): old_qr = self.factory.create_query_result(retrieved_at=yesterday) qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, 60) + found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, 60) self.assertEqual(found_query_result.id, qr.id) @@ -388,7 +388,7 @@ def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): yesterday = utcnow() + datetime.timedelta(days=-1) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, -1) + found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, -1) self.assertEqual(found_query_result.id, qr.id) @@ -492,7 +492,7 @@ def test_stores_the_result(self): self.assertEqual(query_result.data, self.data) self.assertEqual(query_result.runtime, self.runtime) self.assertEqual(query_result.retrieved_at, self.utcnow) - self.assertEqual(query_result.query, self.query) + self.assertEqual(query_result.query_text, self.query) self.assertEqual(query_result.query_hash, self.query_hash) self.assertEqual(query_result.data_source, self.data_source) From 9c1450f4c92a2921bf73973c2fee1709f314a5c5 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 16:24:59 +0200 Subject: [PATCH 29/80] Fix users handlers tests --- redash/handlers/users.py | 5 +++-- redash/models.py | 2 +- redash/tasks/queries.py | 2 +- tests/handlers/test_users.py | 12 ++++++++++-- tests/models/test_changes.py | 10 +++++----- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/redash/handlers/users.py b/redash/handlers/users.py index aa0bac81d3..d79fe27adf 100644 --- a/redash/handlers/users.py +++ b/redash/handlers/users.py @@ -34,7 +34,8 @@ def post(self): group_ids=[self.current_org.default_group.id]) try: - user.save() + models.db.session.add(user) + models.db.session.commit() except IntegrityError as e: if "email" in e.message: abort(400, message='Email already taken.') @@ -108,7 +109,7 @@ def post(self, user_id): abort(403, message="Must be admin to change groups membership.") try: - user.update_instance(**params) + self.update_model(user, params) except IntegrityError as e: if "email" in e.message: message = "Email already taken." diff --git a/redash/models.py b/redash/models.py index 82e6fb692b..c7af685563 100644 --- a/redash/models.py +++ b/redash/models.py @@ -320,7 +320,7 @@ def get_by_api_key_and_org(cls, api_key, org): @classmethod def all(cls, org): - return cls.select().where(cls.org == org) + return cls.query.filter(cls.org == org) @classmethod def find_by_email(cls, email): diff --git a/redash/tasks/queries.py b/redash/tasks/queries.py index 83b7267c96..1174e480e5 100644 --- a/redash/tasks/queries.py +++ b/redash/tasks/queries.py @@ -349,7 +349,7 @@ def refresh_schemas(): logger.info(u"task=refresh_schemas state=start") - for ds in models.DataSource.select(): + for ds in models.DataSource.query: if ds.paused: logger.info(u"task=refresh_schema state=skip ds_id=%s reason=paused(%s)", ds.id, ds.pause_reason) elif ds.id in blacklist: diff --git a/tests/handlers/test_users.py b/tests/handlers/test_users.py index aed7a93743..1d0a67941c 100644 --- a/tests/handlers/test_users.py +++ b/tests/handlers/test_users.py @@ -26,6 +26,14 @@ def test_creates_user(self): self.assertEqual(rv.json['name'], test_user['name']) self.assertEqual(rv.json['email'], test_user['email']) + def test_returns_400_when_email_taken(self): + admin = self.factory.create_admin() + + test_user = {'name': 'User', 'email': admin.email, 'password': 'test'} + rv = self.make_request('post', '/api/users', data=test_user, user=admin) + + self.assertEqual(rv.status_code, 400) + class TestUserListGet(BaseTestCase): def test_returns_users_for_given_org_only(self): @@ -97,10 +105,10 @@ def test_changes_password(self): old_password = "old password" self.factory.user.hash_password(old_password) - self.factory.user.save() + models.db.session.add(self.factory.user) rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"password": new_password, "old_password": old_password}) self.assertEqual(rv.status_code, 200) - user = models.User.get_by_id(self.factory.user.id) + user = models.User.query.get(self.factory.user.id) self.assertTrue(user.verify_password(new_password)) diff --git a/tests/models/test_changes.py b/tests/models/test_changes.py index 549ddc87af..683aa01d95 100644 --- a/tests/models/test_changes.py +++ b/tests/models/test_changes.py @@ -1,12 +1,12 @@ from tests import BaseTestCase -from redash.models import Query, Change, ChangeTrackingMixin +from redash.models import db, Query, Change, ChangeTrackingMixin def create_object(factory): obj = Query(name='Query', description='', - query='SELECT 1', + query_text='SELECT 1', user=factory.user, data_source=factory.data_source, org=factory.org) @@ -23,7 +23,7 @@ def test_returns_initial_state(self): def test_returns_no_changes_after_save(self): obj = create_object(self.factory) - obj.save() + db.session.add(obj) self.assertEqual({}, obj.changes) @@ -32,7 +32,7 @@ class TestLogChange(BaseTestCase): def obj(self): obj = Query(name='Query', description='', - query='SELECT 1', + query_text='SELECT 1', user=self.factory.user, data_source=self.factory.data_source, org=self.factory.org) @@ -72,7 +72,7 @@ def test_properly_log_modification(self): self.assertIn('description', change.change) def test_logs_create_method(self): - q = Query.create(name='Query', description='', query='', user=self.factory.user, + q = Query.create(name='Query', description='', query_text='', user=self.factory.user, data_source=self.factory.data_source, org=self.factory.org) change = Change.last_change(q) From 802b812932684f4c37c090bc9ca14c2519524ca6 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 16:52:30 +0200 Subject: [PATCH 30/80] Fix query results tests --- redash/handlers/query_results.py | 2 +- tests/factories.py | 2 +- tests/handlers/test_query_results.py | 11 +++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/redash/handlers/query_results.py b/redash/handlers/query_results.py index 6f9c2a6595..015dc24094 100644 --- a/redash/handlers/query_results.py +++ b/redash/handlers/query_results.py @@ -111,7 +111,7 @@ def get(self, query_id=None, query_result_id=None, filetype='json'): if query_result_id is None and query_id is not None: query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) if query: - query_result_id = query._data['latest_query_data'] + query_result_id = query.latest_query_data_id if query_result_id: query_result = get_object_or_404(models.QueryResult.get_by_id_and_org, query_result_id, self.current_org) diff --git a/tests/factories.py b/tests/factories.py index 29de2fb58b..210bfc357f 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -251,10 +251,10 @@ def create_data_source(self, **kwargs): if group and 'org' not in kwargs: args['org'] = group.org + view_only = args.pop('view_only', False) data_source = data_source_factory.create(**args) if group: - view_only = kwargs.pop('view_only', False) db.session.add(redash.models.DataSourceGroup( group=group, data_source=data_source, diff --git a/tests/handlers/test_query_results.py b/tests/handlers/test_query_results.py index bc86a29daf..00f018a543 100644 --- a/tests/handlers/test_query_results.py +++ b/tests/handlers/test_query_results.py @@ -1,5 +1,6 @@ import json from tests import BaseTestCase +from redash.models import db class TestQueryResultsCacheHeaders(BaseTestCase): @@ -31,7 +32,7 @@ def test_get_existing_result(self): rv = self.make_request('post', '/api/query_results', data={'data_source_id': self.factory.data_source.id, - 'query': query.query}) + 'query': query.query_text}) self.assertEquals(rv.status_code, 200) self.assertEquals(query_result.id, rv.json['query_result']['id']) @@ -41,7 +42,7 @@ def test_execute_new_query(self): rv = self.make_request('post', '/api/query_results', data={'data_source_id': self.factory.data_source.id, - 'query': query.query, + 'query': query.query_text, 'max_age': 0}) self.assertEquals(rv.status_code, 200) @@ -49,12 +50,14 @@ def test_execute_new_query(self): self.assertIn('job', rv.json) def test_execute_query_without_access(self): - user = self.factory.create_user(groups=[self.factory.create_group().id]) + group = self.factory.create_group() + db.session.commit() + user = self.factory.create_user(group_ids=[group.id]) query = self.factory.create_query() rv = self.make_request('post', '/api/query_results', data={'data_source_id': self.factory.data_source.id, - 'query': query.query, + 'query': query.query_text, 'max_age': 0}, user=user) From c5548e937502814de3e3525f2881357c4f990c01 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 17:39:33 +0200 Subject: [PATCH 31/80] Fix query handlers test --- redash/handlers/queries.py | 10 +++++++--- redash/models.py | 1 - tests/handlers/test_queries.py | 23 +++++++++++++---------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index 1a781f94eb..1164602011 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -60,11 +60,14 @@ def post(self): if 'latest_query_data_id' in query_def: query_def['latest_query_data'] = query_def.pop('latest_query_data_id') + query_def['query_text'] = query_def.pop('query') query_def['user'] = self.current_user query_def['data_source'] = data_source query_def['org'] = self.current_org query_def['is_draft'] = True - query = models.Query.create(**query_def) + query = models.Query(**query_def) + models.db.session.add(query) + models.db.session.commit() self.record_event({ 'action': 'create', @@ -114,7 +117,8 @@ def post(self, query_id): query_def['changed_by'] = self.current_user try: - query.update_instance(**query_def) + self.update_model(query, query_def) + models.db.session.commit() except models.ConflictDetectedError: abort(409) @@ -152,4 +156,4 @@ def post(self, query_id): parameter_values = collect_parameters_from_request(request.args) - return run_query(query.data_source, parameter_values, query.query, query.id) + return run_query(query.data_source, parameter_values, query.query_text, query.id) diff --git a/redash/models.py b/redash/models.py index c7af685563..3fd0d95bb2 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1280,7 +1280,6 @@ class Event(db.Model): user_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) user = db.relationship(User, backref="events") action = Column(db.String(255)) - # XXX replace with association table object_type = Column(db.String(255)) object_id = Column(db.String(255), nullable=True) additional_properties = Column(db.Text, nullable=True) diff --git a/tests/handlers/test_queries.py b/tests/handlers/test_queries.py index ce7e24484c..368eb41b6d 100644 --- a/tests/handlers/test_queries.py +++ b/tests/handlers/test_queries.py @@ -1,5 +1,6 @@ from tests import BaseTestCase from redash import models +from redash.models import db from redash.permissions import ACCESS_TYPE_MODIFY @@ -26,7 +27,7 @@ def test_get_all_queries(self): def test_query_without_data_source_should_be_available_only_by_admin(self): query = self.factory.create_query() query.data_source = None - query.save() + db.session.add(query) rv = self.make_request('get', '/api/queries/{}'.format(query.id)) self.assertEquals(rv.status_code, 403) @@ -40,7 +41,7 @@ def test_query_only_accessible_to_users_from_its_organization(self): query = self.factory.create_query() query.data_source = None - query.save() + db.session.add(query) rv = self.make_request('get', '/api/queries/{}'.format(query.id), user=second_org_admin) self.assertEquals(rv.status_code, 404) @@ -62,7 +63,7 @@ def test_update_query(self): def test_raises_error_in_case_of_conflict(self): q = self.factory.create_query() q.name = "Another Name" - q.save() + db.session.add(q) rv = self.make_request('post', '/api/queries/{0}'.format(q.id), data={'name': 'Testing', 'version': q.version - 1}, user=self.factory.user) self.assertEqual(rv.status_code, 409) @@ -70,7 +71,7 @@ def test_raises_error_in_case_of_conflict(self): def test_overrides_existing_if_no_version_specified(self): q = self.factory.create_query() q.name = "Another Name" - q.save() + db.session.add(q) rv = self.make_request('post', '/api/queries/{0}'.format(q.id), data={'name': 'Testing'}, user=self.factory.user) self.assertEqual(rv.status_code, 200) @@ -107,7 +108,7 @@ def test_create_query(self): self.assertIsNotNone(rv.json['api_key']) self.assertIsNotNone(rv.json['query_hash']) - query = models.Query.get_by_id(rv.json['id']) + query = models.Query.query.get(rv.json['id']) self.assertEquals(len(list(query.visualizations)), 1) self.assertTrue(query.is_draft) @@ -124,21 +125,23 @@ def test_refresh_regular_query(self): self.assertEqual(200, response.status_code) def test_refresh_of_query_with_parameters(self): - self.query.query = "SELECT {{param}}" - self.query.save() + self.query.query_text = u"SELECT {{param}}" + db.session.add(self.query) response = self.make_request('post', "{}?p_param=1".format(self.path)) self.assertEqual(200, response.status_code) def test_refresh_of_query_with_parameters_without_parameters(self): - self.query.query = "SELECT {{param}}" - self.query.save() + self.query.query_text = u"SELECT {{param}}" + db.session.add(self.query) response = self.make_request('post', "{}".format(self.path)) self.assertEqual(400, response.status_code) def test_refresh_query_you_dont_have_access_to(self): group = self.factory.create_group() - user = self.factory.create_user(groups=[group.id]) + db.session.add(group) + db.session.commit() + user = self.factory.create_user(group_ids=[group.id]) response = self.make_request('post', self.path, user=user) self.assertEqual(403, response.status_code) From d103e3f7bf403b5e99e5b4b44551f03d518f8dab Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 30 Nov 2016 23:15:00 +0200 Subject: [PATCH 32/80] Fix groups tests (except for delete) --- redash/handlers/groups.py | 3 ++- tests/handlers/test_embed.py | 8 ++++---- tests/handlers/test_groups.py | 8 +++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/redash/handlers/groups.py b/redash/handlers/groups.py index 61025fd5f7..9eb72df89f 100644 --- a/redash/handlers/groups.py +++ b/redash/handlers/groups.py @@ -64,7 +64,8 @@ def delete(self, group_id): if group.type == models.Group.BUILTIN_GROUP: abort(400, message="Can't delete built-in groups.") - group.delete_instance(recursive=True) + models.db.session.delete(group) + models.db.session.commit() class GroupMemberListResource(BaseResource): diff --git a/tests/handlers/test_embed.py b/tests/handlers/test_embed.py index 3185af3a91..6ae8830945 100644 --- a/tests/handlers/test_embed.py +++ b/tests/handlers/test_embed.py @@ -1,14 +1,14 @@ -from redash import settings from tests import BaseTestCase +from redash.models import db class TestEmbedVisualization(BaseTestCase): def test_sucesss(self): vis = self.factory.create_visualization() - vis.query.latest_query_data = self.factory.create_query_result() - vis.query.save() + vis.query_rel.latest_query_data = self.factory.create_query_result() + db.session.add(vis.query_rel) - res = self.make_request("get", "/embed/query/{}/visualization/{}".format(vis.query.id, vis.id), is_json=False) + res = self.make_request("get", "/embed/query/{}/visualization/{}".format(vis.query_rel.id, vis.id), is_json=False) self.assertEqual(res.status_code, 200) # TODO: bring back? diff --git a/tests/handlers/test_groups.py b/tests/handlers/test_groups.py index f07028b849..38ee2113a6 100644 --- a/tests/handlers/test_groups.py +++ b/tests/handlers/test_groups.py @@ -1,6 +1,5 @@ from tests import BaseTestCase -from tests.factories import org_factory -from redash.models import Group, DataSource +from redash.models import Group, DataSource, NoResultFound class TestGroupDataSourceListResource(BaseTestCase): @@ -21,7 +20,7 @@ def test_doesnt_change_builtin_groups(self): data={'name': 'Another Name'}) self.assertEqual(response.status_code, 400) - self.assertEqual(current_name, Group.get_by_id(self.factory.default_group.id).name) + self.assertEqual(current_name, Group.query.get(self.factory.default_group.id).name) class TestGroupResourceDelete(BaseTestCase): @@ -33,8 +32,7 @@ def test_allowed_only_to_admin(self): response = self.make_request('delete', '/api/groups/{}'.format(group.id), user=self.factory.create_admin()) self.assertEqual(response.status_code, 200) - - self.assertRaises(Group.DoesNotExist, Group.get_by_id, group.id) + self.assertIsNone(Group.query.get(group.id)) def test_cant_delete_builtin_group(self): for group in [self.factory.default_group, self.factory.admin_group]: From 29cdfcd7a1a2125a334996aba6650e11d8834408 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 12:08:02 -0600 Subject: [PATCH 33/80] test_embed, test_groups --- redash/models.py | 7 +++++-- tests/handlers/test_groups.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/redash/models.py b/redash/models.py index 3fd0d95bb2..a91703bb68 100644 --- a/redash/models.py +++ b/redash/models.py @@ -218,6 +218,8 @@ class Group(db.Model, BelongsToOrgMixin): REGULAR_GROUP = 'regular' id = Column(db.Integer, primary_key=True) + data_sources = db.relationship("DataSourceGroup", back_populates="group", + cascade="all") org_id = Column(db.Integer, db.ForeignKey('organizations.id')) org = db.relationship(Organization, back_populates="groups") type = Column(db.String(255), default=REGULAR_GROUP) @@ -368,7 +370,8 @@ class DataSource(BelongsToOrgMixin, db.Model): scheduled_queue_name = Column(db.String(255), default="scheduled_queries") created_at = Column(db.DateTime(True), default=db.func.now()) - data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source") + data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source", + cascade="all") __tablename__ = 'data_sources' __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) @@ -486,7 +489,7 @@ class DataSourceGroup(db.Model): data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) data_source = db.relationship(DataSource, back_populates="data_source_groups") group_id = Column(db.Integer, db.ForeignKey("groups.id")) - group = db.relationship(Group, backref="data_sources") + group = db.relationship(Group, back_populates="data_sources") view_only = Column(db.Boolean, default=False) __tablename__ = "data_source_groups" diff --git a/tests/handlers/test_groups.py b/tests/handlers/test_groups.py index 38ee2113a6..5978869bc4 100644 --- a/tests/handlers/test_groups.py +++ b/tests/handlers/test_groups.py @@ -47,4 +47,4 @@ def test_can_delete_group_with_data_sources(self): self.assertEqual(response.status_code, 200) - self.assertEqual(data_source, DataSource.get_by_id(data_source.id)) + self.assertEqual(data_source, DataSource.query.get(data_source.id)) From 9f789d3018ca0bb949d7d2eaf7b149239d3b7d97 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 12:11:14 -0600 Subject: [PATCH 34/80] test_paginate --- tests/handlers/test_paginate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/handlers/test_paginate.py b/tests/handlers/test_paginate.py index 522d0567fe..b251af9ba0 100644 --- a/tests/handlers/test_paginate.py +++ b/tests/handlers/test_paginate.py @@ -4,8 +4,10 @@ from unittest import TestCase from mock import MagicMock -dummy_results = [i for i in range(25)] +class DummyResults(object): + items = [i for i in range(25)] +dummy_results = DummyResults() class TestPaginate(TestCase): def setUp(self): @@ -18,7 +20,7 @@ def test_returns_paginated_results(self): self.assertEqual(page['page'], 1) self.assertEqual(page['page_size'], 25) self.assertEqual(page['count'], 102) - self.assertEqual(page['results'], dummy_results) + self.assertEqual(page['results'], dummy_results.items) def test_raises_error_for_bad_page(self): self.assertRaises(BadRequest, lambda: paginate(self.query_set, -1, 25, lambda x: x)) From 261b374924b2a9aa24f52f18a00bdf27b557a688 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 12:26:02 -0600 Subject: [PATCH 35/80] test_permissions --- redash/handlers/permissions.py | 21 +++++++++++++-------- redash/models.py | 22 +++++++++++----------- tests/handlers/test_permissions.py | 6 +++--- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/redash/handlers/permissions.py b/redash/handlers/permissions.py index 78a5eba24d..295d19819b 100644 --- a/redash/handlers/permissions.py +++ b/redash/handlers/permissions.py @@ -1,10 +1,11 @@ from collections import defaultdict from redash.handlers.base import BaseResource, get_object_or_404 -from redash.models import AccessPermission, Query, Dashboard, User +from redash.models import AccessPermission, Query, Dashboard, User, db from redash.permissions import require_admin_or_owner, ACCESS_TYPES from flask import request from flask_restful import abort +from sqlalchemy.orm.exc import NoResultFound model_to_types = { @@ -50,7 +51,7 @@ def post(self, object_type, object_id): try: grantee = User.get_by_id_and_org(req['user_id'], self.current_org) - except User.DoesNotExist: + except NoResultFound: abort(400, message='User not found.') permission = AccessPermission.grant(obj, access_type, grantee, self.current_user) @@ -67,30 +68,34 @@ def post(self, object_type, object_id): def delete(self, object_type, object_id): model = get_model_from_type(object_type) - obj = get_object_or_404(model.get_by_id_and_org, object_id, self.current_org) + obj = get_object_or_404(model.get_by_id_and_org, object_id, + self.current_org) require_admin_or_owner(obj.user_id) req = request.get_json(True) - grantee = req['user_id'] + grantee_id = req['user_id'] access_type = req['access_type'] - AccessPermission.revoke(obj, grantee, access_type) + AccessPermission.revoke(obj, grantee_id, access_type) self.record_event({ 'action': 'revoke_permission', 'object_id': object_id, 'object_type': object_type, 'access_type': access_type, - 'grantee': grantee + 'grantee_id': grantee_id }) + db.session.commit() class CheckPermissionResource(BaseResource): def get(self, object_type, object_id, access_type): model = get_model_from_type(object_type) - obj = get_object_or_404(model.get_by_id_and_org, object_id, self.current_org) + obj = get_object_or_404(model.get_by_id_and_org, object_id, + self.current_org) - has_access = AccessPermission.exists(obj, access_type, self.current_user) + has_access = AccessPermission.exists(obj, access_type, + self.current_user.id) return {'response': has_access} diff --git a/redash/models.py b/redash/models.py index a91703bb68..5a38cea37c 100644 --- a/redash/models.py +++ b/redash/models.py @@ -892,31 +892,31 @@ def grant(cls, obj, access_type, grantee, grantor): return grant @classmethod - def revoke(cls, obj, grantee, access_type=None): - permissions = cls._query(obj, access_type, grantee) + def revoke(cls, obj, grantee_id, access_type=None): + permissions = cls._query(obj, access_type, grantee_id) return permissions.delete() @classmethod - def find(cls, obj, access_type=None, grantee=None, grantor=None): - return cls._query(obj, access_type, grantee, grantor) + def find(cls, obj, access_type=None, grantee_id=None, grantor_id=None): + return cls._query(obj, access_type, grantee_id, grantor_id) @classmethod - def exists(cls, obj, access_type, grantee): - return cls.find(obj, access_type, grantee).count() > 0 + def exists(cls, obj, access_type, grantee_id): + return cls.find(obj, access_type, grantee_id).count() > 0 @classmethod - def _query(cls, obj, access_type=None, grantee=None, grantor=None): + def _query(cls, obj, access_type=None, grantee_id=None, grantor_id=None): q = cls.query.filter(cls.object_id == obj.id, cls.object_type == obj.__tablename__) if access_type: q.filter(AccessPermission.access_type == access_type) - if grantee: - q.filter(AccessPermission.grantee_id == grantee.id) + if grantee_id: + q.filter(AccessPermission.grantee_id == grantee_id) - if grantor: - q.filter(AccessPermission.grantor_id == grantor.id) + if grantor_id: + q.filter(AccessPermission.grantor_id == grantor_id) return q diff --git a/tests/handlers/test_permissions.py b/tests/handlers/test_permissions.py index 6910e167c9..ef1a0a76d3 100644 --- a/tests/handlers/test_permissions.py +++ b/tests/handlers/test_permissions.py @@ -47,7 +47,7 @@ def test_creates_permission_if_the_user_is_an_owner(self): rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) self.assertEqual(200, rv.status_code) - self.assertTrue(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) + self.assertTrue(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) def test_returns_403_if_the_user_isnt_owner(self): query = self.factory.create_query() @@ -116,7 +116,7 @@ def test_removes_permission(self): self.assertEqual(rv.status_code, 200) - self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) + self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) def test_removes_permission_created_by_another_user(self): query = self.factory.create_query() @@ -134,7 +134,7 @@ def test_removes_permission_created_by_another_user(self): self.assertEqual(rv.status_code, 200) - self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) + self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) def test_returns_404_for_outside_of_organization_users(self): query = self.factory.create_query() From 6c3d5d184e6c8524b8fe227a0ba5c69cfb330345 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 12:33:14 -0600 Subject: [PATCH 36/80] test_queries --- redash/handlers/queries.py | 9 ++++++++- redash/models.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index 1164602011..960071fa3f 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -5,6 +5,8 @@ from flask_login import login_required from flask_restful import abort from funcy import distinct, take +from sqlalchemy.orm.exc import StaleDataError + from redash import models from redash.handlers.base import (BaseResource, get_object_or_404, org_scoped_rule, paginate, routes) @@ -115,11 +117,16 @@ def post(self, query_id): query_def['last_modified_by'] = self.current_user query_def['changed_by'] = self.current_user + # SQLAlchemy handles the case where a concurrent transaction beats us + # to the update. But we still have to make sure that we're not starting + # out behind. + if 'version' in query_def and query_def['version'] != query.version: + abort(409) try: self.update_model(query, query_def) models.db.session.commit() - except models.ConflictDetectedError: + except StaleDataError: abort(409) result = query.to_dict(with_visualizations=True) diff --git a/redash/models.py b/redash/models.py index 5a38cea37c..a258f7f33c 100644 --- a/redash/models.py +++ b/redash/models.py @@ -344,7 +344,7 @@ def update_group_assignments(self, group_names): db.session.add(self) def has_access(self, obj, access_type): - return AccessPermission.exists(obj, access_type, grantee=self) + return AccessPermission.exists(obj, access_type, grantee_id=self.id) class Configuration(TypeDecorator): From 271b468bcb6ec860620c3eb1bc81123186027245 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 15:48:31 -0600 Subject: [PATCH 37/80] test_alerts --- tests/models/test_alerts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/test_alerts.py b/tests/models/test_alerts.py index 3d2c3bb34b..c11aef2e6b 100644 --- a/tests/models/test_alerts.py +++ b/tests/models/test_alerts.py @@ -1,5 +1,5 @@ from tests import BaseTestCase -from redash.models import Alert +from redash.models import Alert, db class TestAlertAll(BaseTestCase): @@ -11,8 +11,9 @@ def test_returns_all_alerts_for_given_groups(self): query1 = self.factory.create_query(data_source=ds1) query2 = self.factory.create_query(data_source=ds2) - alert1 = self.factory.create_alert(query=query1) - alert2 = self.factory.create_alert(query=query2) + alert1 = self.factory.create_alert(query_rel=query1) + alert2 = self.factory.create_alert(query_rel=query2) + db.session.flush() alerts = Alert.all(group_ids=[group.id, self.factory.default_group.id]) self.assertIn(alert1, alerts) From 9b5aaa787d178d3927bd5beb06879a8ec8fb7afb Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 30 Nov 2016 22:06:39 -0600 Subject: [PATCH 38/80] test_permissions, test_changes, test_queries --- redash/handlers/permissions.py | 7 +- redash/models.py | 104 +++++++++++++---------------- redash/utils/__init__.py | 4 ++ tests/handlers/test_permissions.py | 6 +- tests/models/test_changes.py | 28 ++++---- tests/models/test_queries.py | 41 +++++++----- 6 files changed, 96 insertions(+), 94 deletions(-) diff --git a/redash/handlers/permissions.py b/redash/handlers/permissions.py index 295d19819b..d274462280 100644 --- a/redash/handlers/permissions.py +++ b/redash/handlers/permissions.py @@ -77,7 +77,10 @@ def delete(self, object_type, object_id): grantee_id = req['user_id'] access_type = req['access_type'] - AccessPermission.revoke(obj, grantee_id, access_type) + grantee = User.query.get(req['user_id']) + if grantee is None: + abort(400, message='User not found.') + AccessPermission.revoke(obj, grantee, access_type) self.record_event({ 'action': 'revoke_permission', @@ -96,6 +99,6 @@ def get(self, object_type, object_id, access_type): self.current_org) has_access = AccessPermission.exists(obj, access_type, - self.current_user.id) + self.current_user) return {'response': has_access} diff --git a/redash/models.py b/redash/models.py index a258f7f33c..6830c85f02 100644 --- a/redash/models.py +++ b/redash/models.py @@ -12,6 +12,7 @@ from flask_login import UserMixin, AnonymousUserMixin from sqlalchemy.dialects import postgresql from sqlalchemy.event import listens_for +from sqlalchemy.inspection import inspect from sqlalchemy.types import TypeDecorator from sqlalchemy.orm import object_session # noinspection PyUnresolvedReferences @@ -90,28 +91,37 @@ class ChangeTrackingMixin(object): skipped_fields = ('id', 'created_at', 'updated_at', 'version') _clean_values = None + def __init__(self, *a, **kw): + super(ChangeTrackingMixin, self).__init__(*a, **kw) + self.record_changes(self.user) + def prep_cleanvalues(self): self.__dict__['_clean_values'] = {} - for c in self.__class__.__table__.c: - self._clean_values[c.name] = None + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + # 'query' is col name but not attr name + self._clean_values[col.name] = None def __setattr__(self, key, value): if self._clean_values is None: self.prep_cleanvalues() - if key in self._clean_values: - previous = getattr(self, key) - self._clean_values[key] = previous + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + previous = getattr(self, attr.key, None) + self._clean_values[col.name] = previous super(ChangeTrackingMixin, self).__setattr__(key, value) def record_changes(self, changed_by): - changes = {} - for k, v in self._clean_values.iteritems(): - if k not in self.skipped_fields: - changes[k] = {'previous': v, 'current': getattr(self, k)} + db.session.add(self) db.session.flush() - db.session.add(Change(object_type=self.__class__.__tablename__, - object=self, + changes = {} + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + if attr.key not in self.skipped_fields: + changes[col.name] = {'previous': self._clean_values[col.name], + 'current': getattr(self, attr.key)} + db.session.add(Change(object=self, object_version=self.version, user=changed_by, change=changes)) @@ -344,7 +354,7 @@ def update_group_assignments(self, group_names): db.session.add(self) def has_access(self, obj, access_type): - return AccessPermission.exists(obj, access_type, grantee_id=self.id) + return AccessPermission.exists(obj, access_type, grantee=self) class Configuration(TypeDecorator): @@ -772,41 +782,23 @@ def recent(cls, groups, user_id=None, limit=20): return query def fork(self, user): - query = self - forked_query = Query() - forked_query.name = 'Copy of (#{}) {}'.format(query.id, query.name) - forked_query.user = user - forked_list = ['org', 'data_source', 'latest_query_data', 'description', 'query', 'query_hash'] - for a in forked_list: - setattr(forked_query, a, getattr(query, a)) - forked_query.save() - - forked_visualizations = [] - for v in query.visualizations: + forked_list = ['org', 'data_source', 'latest_query_data', 'description', + 'query_text', 'query_hash'] + kwargs = {a: getattr(self, a) for a in forked_list} + forked_query = Query(name='Copy of (#{}) {}'.format(self.id, self.name), + user=user, **kwargs) + + for v in self.visualizations: if v.type == 'TABLE': continue forked_v = v.to_dict() forked_v['options'] = v.options - forked_v['query'] = forked_query + forked_v['query_rel'] = forked_query forked_v.pop('id') - forked_visualizations.append(forked_v) - - if len(forked_visualizations) > 0: - with db.database.atomic(): - Visualization.insert_many(forked_visualizations).execute() + forked_query.visualizations.append(Visualization(**forked_v)) + db.session.add(forked_query) return forked_query - def pre_save(self, created): - super(Query, self).pre_save(created) - self.query_hash = utils.gen_query_hash(self.query) - - if self.last_modified_by is None: - self.last_modified_by = self.user - - def post_save(self, created): - if created: - self._create_default_visualizations() - def update_instance_tracked(self, changing_user, old_object=None, *args, **kwargs): self.version += 1 self.update_instance(*args, **kwargs) @@ -856,11 +848,6 @@ def create_defaults(session, ctx, *a): description='', type="TABLE", options="{}")) -@listens_for(ChangeTrackingMixin, 'init') -def create_first_change(obj, args, kwargs): - obj.record_changes(obj.user) - - class AccessPermission(GFKBase, db.Model): id = Column(db.Integer, primary_key=True) @@ -892,31 +879,31 @@ def grant(cls, obj, access_type, grantee, grantor): return grant @classmethod - def revoke(cls, obj, grantee_id, access_type=None): - permissions = cls._query(obj, access_type, grantee_id) + def revoke(cls, obj, grantee, access_type=None): + permissions = cls._query(obj, access_type, grantee) return permissions.delete() @classmethod - def find(cls, obj, access_type=None, grantee_id=None, grantor_id=None): - return cls._query(obj, access_type, grantee_id, grantor_id) + def find(cls, obj, access_type=None, grantee=None, grantor=None): + return cls._query(obj, access_type, grantee, grantor) @classmethod - def exists(cls, obj, access_type, grantee_id): - return cls.find(obj, access_type, grantee_id).count() > 0 + def exists(cls, obj, access_type, grantee): + return cls.find(obj, access_type, grantee).count() > 0 @classmethod - def _query(cls, obj, access_type=None, grantee_id=None, grantor_id=None): + def _query(cls, obj, access_type=None, grantee=None, grantor=None): q = cls.query.filter(cls.object_id == obj.id, cls.object_type == obj.__tablename__) if access_type: q.filter(AccessPermission.access_type == access_type) - if grantee_id: - q.filter(AccessPermission.grantee_id == grantee_id) + if grantee: + q.filter(AccessPermission.grantee == grantee) - if grantor_id: - q.filter(AccessPermission.grantor_id == grantor_id) + if grantor: + q.filter(AccessPermission.grantor == grantor) return q @@ -967,7 +954,10 @@ def log_change(cls, changed_by, obj): @classmethod def last_change(cls, obj): - return cls.select().where(cls.object_type==obj._meta.db_table, cls.object_id==obj.id).limit(1).first() + return db.session.query(cls).filter( + cls.object_id == obj.id, + cls.object_type == obj.__class__.__tablename__).order_by( + cls.object_version.desc()).first() class Alert(TimestampMixin, db.Model): diff --git a/redash/utils/__init__.py b/redash/utils/__init__.py index 32dfe71589..cd16f470ac 100644 --- a/redash/utils/__init__.py +++ b/redash/utils/__init__.py @@ -11,6 +11,7 @@ import pystache from funcy import distinct +from sqlalchemy.orm.query import Query from .human_time import parse_human_time from redash import settings @@ -57,6 +58,9 @@ class JSONEncoder(json.JSONEncoder): """Custom JSON encoding class, to handle Decimal and datetime.date instances.""" def default(self, o): + # Some SQLAlchemy collections are lazy. + if isinstance(o, Query): + return list(o) if isinstance(o, decimal.Decimal): return float(o) diff --git a/tests/handlers/test_permissions.py b/tests/handlers/test_permissions.py index ef1a0a76d3..6910e167c9 100644 --- a/tests/handlers/test_permissions.py +++ b/tests/handlers/test_permissions.py @@ -47,7 +47,7 @@ def test_creates_permission_if_the_user_is_an_owner(self): rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) self.assertEqual(200, rv.status_code) - self.assertTrue(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) + self.assertTrue(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) def test_returns_403_if_the_user_isnt_owner(self): query = self.factory.create_query() @@ -116,7 +116,7 @@ def test_removes_permission(self): self.assertEqual(rv.status_code, 200) - self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) + self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) def test_removes_permission_created_by_another_user(self): query = self.factory.create_query() @@ -134,7 +134,7 @@ def test_removes_permission_created_by_another_user(self): self.assertEqual(rv.status_code, 200) - self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user.id)) + self.assertFalse(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) def test_returns_404_for_outside_of_organization_users(self): query = self.factory.create_query() diff --git a/tests/models/test_changes.py b/tests/models/test_changes.py index 683aa01d95..66a35a12fc 100644 --- a/tests/models/test_changes.py +++ b/tests/models/test_changes.py @@ -18,14 +18,8 @@ class TestChangesProperty(BaseTestCase): def test_returns_initial_state(self): obj = create_object(self.factory) - for k, change in obj.changes.iteritems(): - self.assertIsNone(change['previous']) - - def test_returns_no_changes_after_save(self): - obj = create_object(self.factory) - db.session.add(obj) - - self.assertEqual({}, obj.changes) + for change in Change.query.filter(Change.object == obj): + self.assertIsNone(change.change['previous']) class TestLogChange(BaseTestCase): @@ -41,7 +35,7 @@ def obj(self): def test_properly_logs_first_creation(self): obj = create_object(self.factory) - obj.save(changed_by=self.factory.user) + obj.record_changes(changed_by=self.factory.user) change = Change.last_change(obj) self.assertIsNotNone(change) @@ -49,7 +43,7 @@ def test_properly_logs_first_creation(self): def test_skips_unnecessary_fields(self): obj = create_object(self.factory) - obj.save(changed_by=self.factory.user) + obj.record_changes(changed_by=self.factory.user) change = Change.last_change(obj) self.assertIsNotNone(change) @@ -59,9 +53,11 @@ def test_skips_unnecessary_fields(self): def test_properly_log_modification(self): obj = create_object(self.factory) - obj.save(changed_by=self.factory.user) - - obj.update_instance(name='Query 2', description='description', changed_by=self.factory.user) + obj.record_changes(changed_by=self.factory.user) + obj.name = 'Query 2' + obj.description = 'description' + db.session.flush() + obj.record_changes(changed_by=self.factory.user) change = Change.last_change(obj) @@ -72,9 +68,9 @@ def test_properly_log_modification(self): self.assertIn('description', change.change) def test_logs_create_method(self): - q = Query.create(name='Query', description='', query_text='', user=self.factory.user, - data_source=self.factory.data_source, org=self.factory.org) - + q = Query(name='Query', description='', query_text='', + user=self.factory.user, data_source=self.factory.data_source, + org=self.factory.org) change = Change.last_change(q) self.assertIsNotNone(change) diff --git a/tests/models/test_queries.py b/tests/models/test_queries.py index a6c78c9cb7..af04e92970 100644 --- a/tests/models/test_queries.py +++ b/tests/models/test_queries.py @@ -1,5 +1,5 @@ from tests import BaseTestCase -from redash.models import Query +from redash.models import Query, db class TestApiKeyGetByObject(BaseTestCase): @@ -8,20 +8,25 @@ def assert_visualizations(self, origin_q, origin_v, forked_q, forked_v): self.assertEqual(origin_v.options, forked_v.options) self.assertEqual(origin_v.type, forked_v.type) self.assertNotEqual(origin_v.id, forked_v.id) - self.assertNotEqual(origin_v.query, forked_v.query) - self.assertEqual(forked_q.id, forked_v.query.id) + self.assertNotEqual(origin_v.query_rel, forked_v.query_rel) + self.assertEqual(forked_q.id, forked_v.query_rel.id) def test_fork_with_visualizations(self): # prepare original query and visualizations - data_source = self.factory.create_data_source(group=self.factory.create_group()) - query = self.factory.create_query(data_source=data_source, description="this is description") - visualization_chart = self.factory.create_visualization(query=query, description="chart vis", type="CHART", options="""{"yAxis": [{"type": "linear"}, {"type": "linear", "opposite": true}], "series": {"stacking": null}, "globalSeriesType": "line", "sortX": true, "seriesOptions": {"count": {"zIndex": 0, "index": 0, "type": "line", "yAxis": 0}}, "xAxis": {"labels": {"enabled": true}, "type": "datetime"}, "columnMapping": {"count": "y", "created_at": "x"}, "bottomMargin": 50, "legend": {"enabled": true}}""") - visualization_box = self.factory.create_visualization(query=query, description="box vis", type="BOXPLOT", options="{}") + data_source = self.factory.create_data_source( + group=self.factory.create_group()) + query = self.factory.create_query(data_source=data_source, + description="this is description") + visualization_chart = self.factory.create_visualization( + query_rel=query, description="chart vis", type="CHART", + options="""{"yAxis": [{"type": "linear"}, {"type": "linear", "opposite": true}], "series": {"stacking": null}, "globalSeriesType": "line", "sortX": true, "seriesOptions": {"count": {"zIndex": 0, "index": 0, "type": "line", "yAxis": 0}}, "xAxis": {"labels": {"enabled": true}, "type": "datetime"}, "columnMapping": {"count": "y", "created_at": "x"}, "bottomMargin": 50, "legend": {"enabled": true}}""") + visualization_box = self.factory.create_visualization( + query_rel=query, description="box vis", type="BOXPLOT", + options="{}") fork_user = self.factory.create_user() - forked_query = query.fork(fork_user) - + db.session.flush() forked_visualization_chart = None forked_visualization_box = None @@ -35,14 +40,17 @@ def test_fork_with_visualizations(self): if v.type == "TABLE": count_table += 1 forked_table = v - self.assert_visualizations(query, visualization_chart, forked_query, forked_visualization_chart) - self.assert_visualizations(query, visualization_box, forked_query, forked_visualization_box) + self.assert_visualizations(query, visualization_chart, forked_query, + forked_visualization_chart) + self.assert_visualizations(query, visualization_box, forked_query, + forked_visualization_box) self.assertEqual(forked_query.org, query.org) self.assertEqual(forked_query.data_source, query.data_source) - self.assertEqual(forked_query.latest_query_data, query.latest_query_data) + self.assertEqual(forked_query.latest_query_data, + query.latest_query_data) self.assertEqual(forked_query.description, query.description) - self.assertEqual(forked_query.query, query.query) + self.assertEqual(forked_query.query_text, query.query_text) self.assertEqual(forked_query.query_hash, query.query_hash) self.assertEqual(forked_query.user, fork_user) self.assertEqual(forked_query.description, query.description) @@ -55,8 +63,10 @@ def test_fork_with_visualizations(self): def test_fork_from_query_that_has_no_visualization(self): # prepare original query and visualizations - data_source = self.factory.create_data_source(group=self.factory.create_group()) - query = self.factory.create_query(data_source=data_source, description="this is description") + data_source = self.factory.create_data_source( + group=self.factory.create_group()) + query = self.factory.create_query(data_source=data_source, + description="this is description") fork_user = self.factory.create_user() forked_query = query.fork(fork_user) @@ -70,4 +80,3 @@ def test_fork_from_query_that_has_no_visualization(self): self.assertEqual(count_table, 1) self.assertEqual(count_vis, 1) - From c0f48909a727f491cca6cefb0d6bdf21486940d6 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 1 Dec 2016 15:56:39 +0200 Subject: [PATCH 39/80] Fix destinations handlers code --- redash/handlers/destinations.py | 8 +++---- redash/models.py | 2 +- tests/factories.py | 8 +++++-- tests/handlers/test_destinations.py | 37 +++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 8 deletions(-) create mode 100644 tests/handlers/test_destinations.py diff --git a/redash/handlers/destinations.py b/redash/handlers/destinations.py index 1b8a74451e..98eb8af306 100644 --- a/redash/handlers/destinations.py +++ b/redash/handlers/destinations.py @@ -1,14 +1,11 @@ -import json - from flask import make_response, request from flask.ext.restful import abort -from funcy import project from redash import models from redash.permissions import require_admin from redash.destinations import destinations, get_configuration_schema_for_destination_type from redash.utils.configuration import ConfigurationContainer, ValidationError -from redash.handlers.base import BaseResource, get_object_or_404 +from redash.handlers.base import BaseResource class DestinationTypeListResource(BaseResource): @@ -88,6 +85,7 @@ def post(self): type=req['type'], options=config, user=self.current_user) - destination.save() + + models.db.session.add(destination) return destination.to_dict(all=True) diff --git a/redash/models.py b/redash/models.py index 6830c85f02..d29917dc09 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1368,7 +1368,7 @@ def destination(self): @classmethod def all(cls, org): - notification_destinations = cls.select().where(cls.org==org).order_by(cls.id.asc()) + notification_destinations = cls.query.filter(cls.org==org).order_by(cls.id.asc()) return notification_destinations diff --git a/tests/factories.py b/tests/factories.py index 210bfc357f..55b88c6210 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -127,7 +127,7 @@ def __call__(self): destination_factory = ModelFactory(redash.models.NotificationDestination, org_id=1, user=user_factory.create, - name='Destination', + name=Sequence('Destination {}'), type='slack', options=ConfigurationContainer.from_json('{"url": "https://www.slack.com"}')) @@ -337,4 +337,8 @@ def create_api_key(self, **kwargs): return api_key_factory.create(**args) def create_destination(self, **kwargs): - return destination_factory.create(**kwargs) + args = { + 'org': self.org + } + args.update(kwargs) + return destination_factory.create(**args) diff --git a/tests/handlers/test_destinations.py b/tests/handlers/test_destinations.py new file mode 100644 index 0000000000..2ed7e08725 --- /dev/null +++ b/tests/handlers/test_destinations.py @@ -0,0 +1,37 @@ +from tests import BaseTestCase + + +class TestDestinationListResource(BaseTestCase): + def test_get_returns_all_destinations(self): + d1 = self.factory.create_destination() + d2 = self.factory.create_destination() + + rv = self.make_request('get', '/api/destinations', user=self.factory.user) + self.assertEqual(len(rv.json), 2) + + def test_get_returns_only_destinations_of_current_org(self): + d1 = self.factory.create_destination() + d2 = self.factory.create_destination() + d3 = self.factory.create_destination(org=self.factory.create_org()) + + rv = self.make_request('get', '/api/destinations', user=self.factory.user) + self.assertEqual(len(rv.json), 2) + + def test_post_creates_new_destination(self): + data = { + 'options': {'addresses': 'test@example.com'}, + 'name': 'Test', + 'type': 'email' + } + rv = self.make_request('post', '/api/destinations', user=self.factory.create_admin(), data=data) + self.assertEqual(rv.status_code, 200) + pass + + def test_post_requires_admin(self): + data = { + 'options': {'addresses': 'test@example.com'}, + 'name': 'Test', + 'type': 'email' + } + rv = self.make_request('post', '/api/destinations', user=self.factory.user, data=data) + self.assertEqual(rv.status_code, 403) From 3f547990209e82ffc576f3bfa6fffbc509e645ff Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 4 Dec 2016 10:54:14 +0200 Subject: [PATCH 40/80] More destination handlers test fixes --- redash/handlers/destinations.py | 5 +++-- tests/handlers/test_destinations.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/redash/handlers/destinations.py b/redash/handlers/destinations.py index 98eb8af306..97c2eebaee 100644 --- a/redash/handlers/destinations.py +++ b/redash/handlers/destinations.py @@ -38,14 +38,15 @@ def post(self, destination_id): destination.type = req['type'] destination.name = req['name'] - destination.save() + models.db.session.add(destination) return destination.to_dict(all=True) @require_admin def delete(self, destination_id): destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org) - destination.delete_instance(recursive=True) + models.db.session.delete(destination) + models.db.session.commit() return make_response('', 204) diff --git a/tests/handlers/test_destinations.py b/tests/handlers/test_destinations.py index 2ed7e08725..c48bbfa507 100644 --- a/tests/handlers/test_destinations.py +++ b/tests/handlers/test_destinations.py @@ -1,4 +1,5 @@ from tests import BaseTestCase +from redash.models import NotificationDestination class TestDestinationListResource(BaseTestCase): @@ -35,3 +36,29 @@ def test_post_requires_admin(self): } rv = self.make_request('post', '/api/destinations', user=self.factory.user, data=data) self.assertEqual(rv.status_code, 403) + + +class TestDestinationResource(BaseTestCase): + def test_get(self): + d = self.factory.create_destination() + rv = self.make_request('get', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin()) + self.assertEqual(rv.status_code, 200) + + def test_delete(self): + d = self.factory.create_destination() + rv = self.make_request('delete', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin()) + self.assertEqual(rv.status_code, 204) + self.assertIsNone(NotificationDestination.query.get(d.id)) + + def test_post(self): + d = self.factory.create_destination() + data = { + 'name': 'updated', + 'type': d.type, + 'options': d.options.to_dict() + } + rv = self.make_request('post', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin(), data=data) + self.assertEqual(rv.status_code, 200) + self.assertEqual(NotificationDestination.query.get(d.id).name, data['name']) + + From 7d45812ef731a3371d1254870ea5b1222dfc570b Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 4 Dec 2016 13:49:10 +0200 Subject: [PATCH 41/80] Add setting to enable SQLA echo mode --- redash/settings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/redash/settings.py b/redash/settings.py index 2129db7443..a28b6aa9c1 100644 --- a/redash/settings.py +++ b/redash/settings.py @@ -67,6 +67,7 @@ def all_settings(): # Connection settings for re:dash's own database (where we store the queries, results, etc) SQLALCHEMY_DATABASE_URI = os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql:///postgres")) SQLALCHEMY_TRACK_MODIFICATIONS = False +SQLALCHEMY_ECHO = False # Celery related settings CELERY_BROKER = os.environ.get("REDASH_CELERY_BROKER", REDIS_URL) From 03b2a416c8d717fcd2c4f959c0e84b4580bf9dc9 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 4 Dec 2016 14:03:15 +0200 Subject: [PATCH 42/80] Fix queries update handler --- redash/handlers/queries.py | 7 ++----- redash/models.py | 23 +++++++++++++---------- tests/handlers/test_queries.py | 17 +++++++++++++++-- 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index 960071fa3f..f65b331e5c 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -109,11 +109,8 @@ def post(self, query_id): for field in ['id', 'created_at', 'api_key', 'visualizations', 'latest_query_data', 'user', 'last_modified_by', 'org']: query_def.pop(field, None) - if 'latest_query_data_id' in query_def: - query_def['latest_query_data'] = query_def.pop('latest_query_data_id') - - if 'data_source_id' in query_def: - query_def['data_source'] = query_def.pop('data_source_id') + if 'query' in query_def: + query_def['query_text'] = query_def.pop('query') query_def['last_modified_by'] = self.current_user query_def['changed_by'] = self.current_user diff --git a/redash/models.py b/redash/models.py index d29917dc09..e13978ad8d 100644 --- a/redash/models.py +++ b/redash/models.py @@ -72,8 +72,10 @@ def object(self, value): class PseudoJSON(TypeDecorator): impl = db.Text + def process_bind_param(self, value, dialect): return json_dumps(value) + def process_result_value(self, value, dialect): if not value: return value @@ -120,15 +122,12 @@ def record_changes(self, changed_by): col, = attr.columns if attr.key not in self.skipped_fields: changes[col.name] = {'previous': self._clean_values[col.name], - 'current': getattr(self, attr.key)} - db.session.add(Change(object=self, - object_version=self.version, - user=changed_by, - change=changes)) - + 'current': getattr(self, attr.key)} -class ConflictDetectedError(Exception): - pass + db.session.add(Change(object=self, + object_version=self.version, + user=changed_by, + change=changes)) class BelongsToOrgMixin(object): @@ -645,7 +644,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True): d = { 'id': self.id, - 'latest_query_data_id': self.latest_query_data, + 'latest_query_data_id': self.latest_query_data_id, 'name': self.name, 'description': self.description, 'query': self.query_text, @@ -831,14 +830,17 @@ def groups(self): def __unicode__(self): return unicode(self.id) + @listens_for(Query.query_text, 'set') def gen_query_hash(target, val, oldval, initiator): target.query_hash = utils.gen_query_hash(val) + @listens_for(Query.user_id, 'set') def query_last_modified_by(target, val, oldval, initiator): target.last_modified_by_id = val + # Create default (table) visualization: @listens_for(SignallingSession, 'before_flush') def create_defaults(session, ctx, *a): @@ -1056,7 +1058,8 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model name = Column(db.String(100)) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User) - # XXX replace with association table + # TODO: The layout should dynamically be built from position and size information on each widget. + # Will require update in the frontend code to support this. layout = Column(db.Text) dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) diff --git a/tests/handlers/test_queries.py b/tests/handlers/test_queries.py index 368eb41b6d..e5cfc2b1ff 100644 --- a/tests/handlers/test_queries.py +++ b/tests/handlers/test_queries.py @@ -55,10 +55,23 @@ def test_update_query(self): admin = self.factory.create_admin() query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{0}'.format(query.id), data={'name': 'Testing'}, user=admin) + new_ds = self.factory.create_data_source() + new_qr = self.factory.create_query_result() + + data = { + 'name': 'Testing', + 'query': 'select 2', + 'latest_query_data_id': new_qr.id, + 'data_source_id': new_ds.id + } + + rv = self.make_request('post', '/api/queries/{0}'.format(query.id), data=data, user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], 'Testing') + self.assertEqual(rv.json['name'], data['name']) self.assertEqual(rv.json['last_modified_by']['id'], admin.id) + self.assertEqual(rv.json['query'], data['query']) + self.assertEqual(rv.json['data_source_id'], data['data_source_id']) + self.assertEqual(rv.json['latest_query_data_id'], data['latest_query_data_id']) def test_raises_error_in_case_of_conflict(self): q = self.factory.create_query() From a5805d0700b08fa7621e729f16e3c12ad2c59254 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 4 Dec 2016 14:05:24 +0200 Subject: [PATCH 43/80] Fix session api (used groups instead of group_ids) --- redash/handlers/authentication.py | 2 +- tests/handlers/test_authentication.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index c32b9d7394..a4dfb12a94 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -179,7 +179,7 @@ def session(org_slug=None): 'id': current_user.id, 'name': current_user.name, 'email': current_user.email, - 'groups': current_user.groups, + 'groups': current_user.group_ids, 'permissions': current_user.permissions } else: diff --git a/tests/handlers/test_authentication.py b/tests/handlers/test_authentication.py index 550616c04c..b0477c92aa 100644 --- a/tests/handlers/test_authentication.py +++ b/tests/handlers/test_authentication.py @@ -65,3 +65,9 @@ def test_throttle_login(self): response = self.get_request('/login', org=self.factory.org) self.assertEqual(response.status_code, 429) + + +class TestSession(BaseTestCase): + # really simple test just to trigger this route + def test_get(self): + self.make_request('get', '/api/session', user=self.factory.user) From 520835dad08fef03c7a496e4abda62403c00235e Mon Sep 17 00:00:00 2001 From: Allen Short Date: Mon, 5 Dec 2016 23:01:11 -0600 Subject: [PATCH 44/80] Fix test_cli --- redash/cli/data_sources.py | 45 ++++++++++++--------- redash/cli/database.py | 3 ++ redash/cli/groups.py | 22 ++++++++--- redash/cli/organization.py | 15 ++++--- redash/cli/users.py | 52 ++++++++++++++---------- redash/models.py | 40 +++++++++++++++++-- tests/test_cli.py | 81 +++++++++++++++++++------------------- 7 files changed, 163 insertions(+), 95 deletions(-) diff --git a/redash/cli/data_sources.py b/redash/cli/data_sources.py index 6196f1e153..b7d77e3752 100644 --- a/redash/cli/data_sources.py +++ b/redash/cli/data_sources.py @@ -2,6 +2,8 @@ import json import click +from flask.cli import with_appcontext +from sqlalchemy.orm.exc import NoResultFound from redash import models from redash.query_runner import query_runners @@ -12,6 +14,7 @@ @manager.command() +@with_appcontext @click.option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for " "all organizations).") @@ -19,9 +22,10 @@ def list(organization=None): """List currently configured data sources.""" if organization: org = models.Organization.get_by_slug(organization) - data_sources = models.DataSource.select().where(models.DataSource.org==org.id) + data_sources = models.DataSource.query.filter( + models.DataSource.org == org) else: - data_sources = models.DataSource.select() + data_sources = models.DataSource.query for i, ds in enumerate(data_sources): if i > 0: print "-" * 20 @@ -34,10 +38,12 @@ def validate_data_source_type(type): if type not in query_runners.keys(): print ("Error: the type \"{}\" is not supported (supported types: {})." .format(type, ", ".join(query_runners.keys()))) - exit() + print "OJNK" + exit(1) @manager.command() +@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to " @@ -46,10 +52,9 @@ def test(name, organization='default'): """Test connection to data source by issuing a trivial query.""" try: org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( + data_source = models.DataSource.query.filter( models.DataSource.name == name, - models.DataSource.org == org, - ) + models.DataSource.org == org).one() print "Testing connection to data source: {} (id={})".format( name, data_source.id) try: @@ -59,12 +64,13 @@ def test(name, organization='default'): exit(1) else: print "Success" - except models.DataSource.DoesNotExist: + except NoResultFound: print "Couldn't find data source named: {}".format(name) exit(1) @manager.command() +@with_appcontext @click.argument('name', default=None, required=False) @click.option('--type', default=None, help="new type for the data source") @@ -140,6 +146,7 @@ def new(name=None, type=None, options=None, organization='default'): @manager.command() +@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -148,13 +155,12 @@ def delete(name, organization='default'): """Delete data source by name.""" try: org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( - models.DataSource.name==name, - models.DataSource.org==org, - ) + data_source = models.DataSource.query.filter( + models.DataSource.name == name, + models.DataSource.org == org).one() print "Deleting data source: {} (id={})".format(name, data_source.id) - data_source.delete_instance(recursive=True) - except models.DataSource.DoesNotExist: + models.db.session.delete(data_source) + except NoResultFound: print "Couldn't find data source named: {}".format(name) exit(1) @@ -167,6 +173,7 @@ def update_attr(obj, attr, new_value): @manager.command() +@with_appcontext @click.argument('name') @click.option('--name', 'new_name', default=None, help="new name for the data source") @@ -183,10 +190,9 @@ def edit(name, new_name=None, options=None, type=None, organization='default'): if type is not None: validate_data_source_type(type) org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( - models.DataSource.name==name, - models.DataSource.org==org, - ) + data_source = models.DataSource.query.filter( + models.DataSource.name == name, + models.DataSource.org == org).one() update_attr(data_source, "name", new_name) update_attr(data_source, "type", type) @@ -197,7 +203,8 @@ def edit(name, new_name=None, options=None, type=None, organization='default'): data_source.options.set_schema(schema) data_source.options.update(options) - data_source.save() + models.db.session.add(data_source) + models.db.session.commit() - except models.DataSource.DoesNotExist: + except NoResultFound: print "Couldn't find data source named: {}".format(name) diff --git a/redash/cli/database.py b/redash/cli/database.py index 0de7e257e0..57fb234146 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -1,9 +1,11 @@ from click import Group +from flask.cli import with_appcontext manager = Group(help="Manage the database (create/drop tables).") @manager.command() +@with_appcontext def create_tables(): """Create the database tables.""" from redash.models import db, create_db, init_db @@ -13,6 +15,7 @@ def create_tables(): @manager.command() +@with_appcontext def drop_tables(): """Drop the database tables.""" from redash.models import create_db diff --git a/redash/cli/groups.py b/redash/cli/groups.py index 7aae322db7..92b264f165 100644 --- a/redash/cli/groups.py +++ b/redash/cli/groups.py @@ -1,12 +1,16 @@ from sys import exit +from sqlalchemy.orm.exc import NoResultFound +from flask.cli import with_appcontext from click import Group, argument, option + from redash import models manager = Group(help="Groups management commands.") @manager.command() +@with_appcontext @argument('name') @option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -27,13 +31,17 @@ def create(name, permissions=None, organization='default'): print "permissions: [%s]" % ",".join(permissions) try: - models.Group.create(name=name, org=org, permissions=permissions) + models.db.session.add(models.Group( + name=name, org=org, + permissions=permissions)) + models.db.session.commit() except Exception, e: print "Failed create group: %s" % e.message exit(1) @manager.command() +@with_appcontext @argument('group_id') @option('--permissions', default=None, help="Comma separated list of permissions ('create_dashboard'," @@ -45,8 +53,8 @@ def change_permissions(group_id, permissions=None): print "Change permissions of group %s ..." % group_id try: - group = models.Group.get_by_id(group_id) - except models.Group.DoesNotExist: + group = models.Group.query.get(group_id) + except NoResultFound: print "User [%s] not found." % group_id exit(1) @@ -57,7 +65,8 @@ def change_permissions(group_id, permissions=None): group.permissions = permissions try: - group.save() + models.db.session.add(group) + models.db.session.commit() except Exception, e: print "Failed change permission: %s" % e.message exit(1) @@ -73,15 +82,16 @@ def extract_permissions_string(permissions): @manager.command() +@with_appcontext @option('--org', 'organization', default=None, help="The organization to limit to (leave blank for all).") def list(organization=None): """List all groups""" if organization: org = models.Organization.get_by_slug(organization) - groups = models.Group.select().where(models.Group.org == org) + groups = models.Group.query.filter(models.Group.org == org) else: - groups = models.Group.select() + groups = models.Group.query for i, group in enumerate(groups): if i > 0: diff --git a/redash/cli/organization.py b/redash/cli/organization.py index 04e08687a8..73daef0658 100644 --- a/redash/cli/organization.py +++ b/redash/cli/organization.py @@ -1,35 +1,40 @@ from click import Group, argument +from flask.cli import with_appcontext + from redash import models manager = Group(help="Organization management commands.") @manager.command() +@with_appcontext @argument('domains') def set_google_apps_domains(domains): """ Sets the allowable domains to the comma separated list DOMAINS. """ - organization = models.Organization.select().first() + organization = models.Organization.query.first() k = models.Organization.SETTING_GOOGLE_APPS_DOMAINS organization.settings[k] = domains.split(',') - organization.save() - + models.db.session.add(organization) + models.db.session.commit() print "Updated list of allowed domains to: {}".format( organization.google_apps_domains) @manager.command() +@with_appcontext def show_google_apps_domains(): - organization = models.Organization.select().first() + organization = models.Organization.query.first() print "Current list of Google Apps domains: {}".format( ', '.join(organization.google_apps_domains)) @manager.command() +@with_appcontext def list(): """List all organizations""" - orgs = models.Organization.select() + orgs = models.Organization.query for i, org in enumerate(orgs): if i > 0: print "-" * 20 diff --git a/redash/cli/users.py b/redash/cli/users.py index 4f2ffdadfe..11049c73f2 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -1,6 +1,8 @@ from sys import exit from click import BOOL, Group, argument, option, prompt +from flask.cli import with_appcontext +from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.exc import IntegrityError from redash import models @@ -25,6 +27,7 @@ def build_groups(org, groups, is_admin): @manager.command() +@with_appcontext @argument('email') @option('--org', 'organization', default='default', help="the organization the user belongs to, (leave blank for " @@ -38,18 +41,19 @@ def grant_admin(email, organization='default'): admin_group = org.admin_group user = models.User.get_by_email_and_org(email, org) - if admin_group.id in user.groups: + if admin_group.id in user.group_ids: print "User is already an admin." else: - user.groups.append(org.admin_group.id) - user.save() - + user.group_ids = user.group_ids + [org.admin_group.id] + models.db.session.add(user) + models.db.session.commit() print "User updated." - except models.User.DoesNotExist: + except NoResultFound: print "User [%s] not found." % email @manager.command() +@with_appcontext @argument('email') @argument('name') @option('--org', 'organization', default='default', @@ -78,7 +82,7 @@ def create(email, name, groups, is_admin=False, google_auth=False, org = models.Organization.get_by_slug(organization) groups = build_groups(org, groups, is_admin) - user = models.User(org=org, email=email, name=name, groups=groups) + user = models.User(org=org, email=email, name=name, group_ids=groups) if not password and not google_auth: password = prompt("Password", hide_input=True, confirmation_prompt=True) @@ -86,13 +90,15 @@ def create(email, name, groups, is_admin=False, google_auth=False, user.hash_password(password) try: - user.save() + models.db.session.add(user) + models.db.session.commit() except Exception, e: print "Failed creating user: %s" % e.message exit(1) @manager.command() +@with_appcontext @argument('email') @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" @@ -103,16 +109,17 @@ def delete(email, organization=None): """ if organization: org = models.Organization.get_by_slug(organization) - deleted_count = models.User.delete().where( + deleted_count = models.User.query.filter( models.User.email == email, models.User.org == org.id, - ).execute() + ).delete() else: - deleted_count = models.User.delete().where(models.User.email == email).execute() + deleted_count = models.User.query.filter(models.User.email == email).delete() print "Deleted %d users." % deleted_count @manager.command() +@with_appcontext @argument('email') @argument('password') @option('--org', 'organization', default=None, @@ -124,16 +131,17 @@ def password(email, password, organization=None): """ if organization: org = models.Organization.get_by_slug(organization) - user = models.User.select().where( + user = models.User.query.filter( models.User.email == email, - models.User.org == org.id, + models.User.org == org, ).first() else: - user = models.User.select().where(models.User.email == email).first() + user = models.User.query.filter(models.User.email == email).first() if user is not None: user.hash_password(password) - user.save() + models.db.session.add(user) + models.db.session.commit() print "User updated." else: print "User [%s] not found." % email @@ -141,6 +149,7 @@ def password(email, password, organization=None): @manager.command() +@with_appcontext @argument('email') @argument('name') @argument('inviter_email') @@ -159,22 +168,23 @@ def invite(email, name, inviter_email, groups, is_admin=False, groups = build_groups(org, groups, is_admin) try: user_from = models.User.get_by_email_and_org(inviter_email, org) - user = models.User(org=org, name=name, email=email, groups=groups) - + user = models.User(org=org, name=name, email=email, group_ids=groups) + models.db.session.add(user) try: - user.save() - invite_url = invite_user(org, user_from, user) + models.db.session.commit() + invite_user(org, user_from, user) print "An invitation was sent to [%s] at [%s]." % (name, email) except IntegrityError as e: if "email" in e.message: print "Cannot invite. User already exists [%s]" % email else: print e - except models.User.DoesNotExist: + except NoResultFound: print "The inviter [%s] was not found." % inviter_email @manager.command() +@with_appcontext @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" " organizations)") @@ -182,9 +192,9 @@ def list(organization=None): """List all users""" if organization: org = models.Organization.get_by_slug(organization) - users = models.User.select().where(models.User.org==org.id) + users = models.User.query.filter(models.User.org == org) else: - users = models.User.select() + users = models.User.query for i, user in enumerate(users): if i > 0: print "-" * 20 diff --git a/redash/models.py b/redash/models.py index e13978ad8d..789907ed52 100644 --- a/redash/models.py +++ b/redash/models.py @@ -14,6 +14,7 @@ from sqlalchemy.event import listens_for from sqlalchemy.inspection import inspect from sqlalchemy.types import TypeDecorator +from sqlalchemy.ext.mutable import Mutable from sqlalchemy.orm import object_session # noinspection PyUnresolvedReferences from sqlalchemy.orm.exc import NoResultFound @@ -70,6 +71,8 @@ def object(self, value): # return peewee.Expression(self, '::', peewee.SQL(as_type)) +# XXX replace PseudoJSON and MutableDict with real JSON field + class PseudoJSON(TypeDecorator): impl = db.Text @@ -82,6 +85,33 @@ def process_result_value(self, value, dialect): return json.loads(value) +class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + + class TimestampMixin(object): updated_at = Column(db.DateTime(True), default=db.func.now(), onupdate=db.func.now(), nullable=False) @@ -186,7 +216,7 @@ class Organization(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) name = Column(db.String(255)) slug = Column(db.String(255), unique=True) - settings = Column(PseudoJSON) + settings = Column(MutableDict.as_mutable(PseudoJSON)) groups = db.relationship("Group", lazy="dynamic") __tablename__ = 'organizations' @@ -415,7 +445,9 @@ def __unicode__(self): @classmethod def create_with_group(cls, *args, **kwargs): data_source = cls(*args, **kwargs) - data_source_group = DataSourceGroup(data_source=data_source, group=data_source.org.default_group) + data_source_group = DataSourceGroup( + data_source=data_source, + group=data_source.org.default_group) db.session.add_all([data_source, data_source_group]) return data_source @@ -634,7 +666,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): is_draft = Column(db.Boolean, default=True, index=True) schedule = Column(db.String(10), nullable=True) visualizations = db.relationship("Visualization", cascade="all, delete-orphan") - options = Column(PseudoJSON, default={}) + options = Column(MutableDict.as_mutable(PseudoJSON), default={}) __tablename__ = 'queries' __mapper_args__ = { @@ -973,7 +1005,7 @@ class Alert(TimestampMixin, db.Model): query_rel = db.relationship(Query, backref='alerts', cascade="all") user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, backref='alerts') - options = Column(PseudoJSON) + options = Column(MutableDict.as_mutable(PseudoJSON)) state = Column(db.String(255), default=UNKNOWN_STATE) subscriptions = db.relationship("AlertSubscription", cascade="all, delete-orphan") last_triggered_at = Column(db.DateTime(True), nullable=True) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7215a150b0..d292433f9d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,7 +14,7 @@ show_google_apps_domains) from redash.cli.users import (create as create_user, delete as delete_user, grant_admin, invite, list as list_user, password) -from redash.models import DataSource, Group, Organization, User +from redash.models import DataSource, Group, Organization, User, db class DataSourceCommandTests(BaseTestCase): @@ -26,8 +26,8 @@ def test_interactive_new(self): input="test\n%s\n\n\nexample.com\n\ntestdb\n" % (pg_i,)) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['dbname'], 'testdb') @@ -40,8 +40,8 @@ def test_options_new(self): '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['host'], 'example.com') @@ -54,7 +54,7 @@ def test_bad_type_new(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_bad_options_new(self): runner = CliRunner() @@ -65,7 +65,7 @@ def test_bad_options_new(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_list(self): self.factory.create_data_source( @@ -122,7 +122,7 @@ def test_connection_delete(self): self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertIn('Deleting', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_connection_bad_delete(self): self.factory.create_data_source( @@ -133,7 +133,7 @@ def test_connection_bad_delete(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("Couldn't find", result.output) - self.assertEqual(DataSource.select().count(), 1) + self.assertEqual(DataSource.query.count(), 1) def test_options_edit(self): self.factory.create_data_source( @@ -147,8 +147,8 @@ def test_options_edit(self): '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test2') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['host'], 'example.com') @@ -164,7 +164,7 @@ def test_bad_type_edit(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) - ds = DataSource.select().first() + ds = DataSource.query.first() self.assertEqual(ds.type, 'sqlite') def test_bad_options_edit(self): @@ -179,7 +179,7 @@ def test_bad_options_edit(self): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) - ds = DataSource.select().first() + ds = DataSource.query.first() self.assertEqual(ds.type, 'sqlite') self.assertEqual(ds.options._config, {"dbpath": "/tmp/test.db"}) @@ -187,21 +187,21 @@ def test_bad_options_edit(self): class GroupCommandTests(BaseTestCase): def test_create(self): - gcount = Group.select().count() + gcount = Group.query.count() perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() result = runner.invoke( create_group, ['test', '--permissions', ','.join(perms)]) - print result.output self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(Group.select().count(), gcount + 1) - g = Group.select().order_by(Group.id.desc()).first() + self.assertEqual(Group.query.count(), gcount + 1) + g = Group.query.order_by(Group.id.desc()).first() self.assertEqual(g.org, self.factory.org) self.assertEqual(g.permissions, perms) def test_change_permissions(self): g = self.factory.create_group(permissions=['list_dashboards']) + db.session.flush() g_id = g.id perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() @@ -209,7 +209,7 @@ def test_change_permissions(self): change_permissions, [str(g_id), '--permissions', ','.join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - g = Group.select().where(Group.id == g_id).first() + g = Group.query.filter(Group.id == g_id).first() self.assertEqual(g.permissions, perms) def test_list(self): @@ -245,14 +245,15 @@ def test_set_google_apps_domains(self): result = runner.invoke(set_google_apps_domains, [','.join(domains)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - o = Organization.select().where( - Organization.id == self.factory.org.id).first() - self.assertEqual(o.google_apps_domains, domains) + #db.session. + db.session.refresh(self.factory.org) + self.assertEqual(self.factory.org.google_apps_domains, domains) def test_show_google_apps_domains(self): self.factory.org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ 'example.org', 'example.com'] - self.factory.org.save() + db.session.add(self.factory.org) + db.session.commit() runner = CliRunner() result = runner.invoke(show_google_apps_domains, []) self.assertFalse(result.exception) @@ -290,10 +291,10 @@ def test_create_basic(self): input="password1\npassword1\n") self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.groups, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [self.factory.default_group.id]) def test_create_admin(self): runner = CliRunner() @@ -302,10 +303,10 @@ def test_create_admin(self): '--password', 'password1', '--admin']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.groups, [self.factory.default_group.id, + self.assertEqual(u.group_ids, [self.factory.default_group.id, self.factory.admin_group.id]) def test_create_googleauth(self): @@ -314,10 +315,10 @@ def test_create_googleauth(self): create_user, ['foobar@example.com', 'Fred Foobar', '--google']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertIsNone(u.password_hash) - self.assertEqual(u.groups, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [self.factory.default_group.id]) def test_create_bad(self): self.factory.create_user(email='foobar@example.com') @@ -331,23 +332,23 @@ def test_create_bad(self): def test_delete(self): self.factory.create_user(email='foobar@example.com') - ucount = User.select().count() + ucount = User.query.count() runner = CliRunner() result = runner.invoke( delete_user, ['foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(User.select().where(User.email == + self.assertEqual(User.query.filter(User.email == "foobar@example.com").count(), 0) - self.assertEqual(User.select().count(), ucount - 1) + self.assertEqual(User.query.count(), ucount - 1) def test_delete_bad(self): - ucount = User.select().count() + ucount = User.query.count() runner = CliRunner() result = runner.invoke( delete_user, ['foobar@example.com']) self.assertIn('Deleted 0 users', result.output) - self.assertEqual(User.select().count(), ucount) + self.assertEqual(User.query.count(), ucount) def test_password(self): self.factory.create_user(email='foobar@example.com') @@ -356,7 +357,7 @@ def test_password(self): password, ['foobar@example.com', 'xyzzy']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertTrue(u.verify_password('xyzzy')) def test_password_bad(self): @@ -394,7 +395,7 @@ def test_invite(self): def test_list(self): self.factory.create_user(name='Fred Foobar', email='foobar@example.com', - organization=self.factory.org) + org=self.factory.org) runner = CliRunner() result = runner.invoke(list_user, []) self.assertFalse(result.exception) @@ -409,15 +410,15 @@ def test_list(self): textwrap.dedent(output).lstrip()) def test_grant_admin(self): - self.factory.create_user(name='Fred Foobar', + u = self.factory.create_user(name='Fred Foobar', email='foobar@example.com', org=self.factory.org, - groups=[self.factory.default_group.id]) + group_ids=[self.factory.default_group.id]) runner = CliRunner() result = runner.invoke( grant_admin, ['foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().order_by(User.id.desc()).first() - self.assertEqual(u.groups, [self.factory.default_group.id, - self.factory.admin_group.id]) + db.session.refresh(u) + self.assertEqual(u.group_ids, [self.factory.default_group.id, + self.factory.admin_group.id]) From f3d813445b4bc5af2b93921332c83b2a2cd217ad Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 29 Nov 2016 17:54:57 +0200 Subject: [PATCH 45/80] Fix data source models tests --- redash/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/redash/models.py b/redash/models.py index 789907ed52..81b9bade37 100644 --- a/redash/models.py +++ b/redash/models.py @@ -448,7 +448,6 @@ def create_with_group(cls, *args, **kwargs): data_source_group = DataSourceGroup( data_source=data_source, group=data_source.org.default_group) - db.session.add_all([data_source, data_source_group]) return data_source From 8280859ad36f0a9fc367251738df76b779c9a3b3 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 00:38:18 -0600 Subject: [PATCH 46/80] add test coverage for remaining queries to convert --- redash/handlers/groups.py | 13 +++-- redash/handlers/query_snippets.py | 27 +++++---- redash/models.py | 8 +-- redash/serializers.py | 8 +-- tests/handlers/test_embed.py | 30 ++++++++++ tests/handlers/test_groups.py | 50 ++++++++++++++++- tests/test_handlers.py | 92 ++++++++++++++++++++++++++++++- 7 files changed, 198 insertions(+), 30 deletions(-) diff --git a/redash/handlers/groups.py b/redash/handlers/groups.py index 9eb72df89f..c43a5518f9 100644 --- a/redash/handlers/groups.py +++ b/redash/handlers/groups.py @@ -25,7 +25,8 @@ def get(self): if self.current_user.has_permission('admin'): groups = models.Group.all(self.current_org) else: - groups = models.Group.select().where(models.Group.id << self.current_user.groups) + groups = models.Group.query.filter( + models.Group.id.in_(self.current_user.group_ids)) return [g.to_dict() for g in groups] @@ -133,12 +134,12 @@ def post(self, group_id): @require_admin def get(self, group_id): - group = get_object_or_404(models.Group.get_by_id_and_org, group_id, self.current_org) - + group = get_object_or_404(models.Group.get_by_id_and_org, group_id, + self.current_org) # TOOD: move to models - data_sources = models.DataSource.select(models.DataSource, models.DataSourceGroup.view_only)\ - .join(models.DataSourceGroup)\ - .where(models.DataSourceGroup.group == group) + data_sources = (models.DataSource.query + .join(models.DataSourceGroup) + .filter(models.DataSourceGroup.group == group)) return [ds.to_dict(with_permissions_for=group) for ds in data_sources] diff --git a/redash/handlers/query_snippets.py b/redash/handlers/query_snippets.py index 7614bce685..6a2f45465a 100644 --- a/redash/handlers/query_snippets.py +++ b/redash/handlers/query_snippets.py @@ -3,34 +3,38 @@ from redash import models from redash.permissions import require_admin_or_owner -from redash.handlers.base import BaseResource, require_fields, get_object_or_404 +from redash.handlers.base import (BaseResource, require_fields, + get_object_or_404) class QuerySnippetResource(BaseResource): def get(self, snippet_id): - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org) + snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, + snippet_id, self.current_org) return snippet.to_dict() def post(self, snippet_id): req = request.get_json(True) params = project(req, ('trigger', 'description', 'snippet')) - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org) + snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, + snippet_id, self.current_org) require_admin_or_owner(snippet.user.id) - snippet.update_instance(**params) + self.update_model(snippet, params) self.record_event({ 'action': 'edit', 'object_id': snippet.id, 'object_type': 'query_snippet' }) - + models.db.session.commit() return snippet.to_dict() def delete(self, snippet_id): - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org) + snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, + snippet_id, self.current_org) require_admin_or_owner(snippet.user.id) - snippet.delete_instance() + models.db.session.delete(snippet) self.record_event({ 'action': 'delete', @@ -44,21 +48,22 @@ def post(self): req = request.get_json(True) require_fields(req, ('trigger', 'description', 'snippet')) - snippet = models.QuerySnippet.create( + snippet = models.QuerySnippet( trigger=req['trigger'], description=req['description'], snippet=req['snippet'], user=self.current_user, org=self.current_org ) - + models.db.session.add(snippet) self.record_event({ 'action': 'create', 'object_id': snippet.id, 'object_type': 'query_snippet' }) - + models.db.session.commit() return snippet.to_dict() def get(self): - return [snippet.to_dict() for snippet in models.QuerySnippet.all(org=self.current_org)] + return [snippet.to_dict() for snippet in + models.QuerySnippet.all(org=self.current_org)] diff --git a/redash/models.py b/redash/models.py index 81b9bade37..3989790316 100644 --- a/redash/models.py +++ b/redash/models.py @@ -363,10 +363,6 @@ def get_by_api_key_and_org(cls, api_key, org): def all(cls, org): return cls.query.filter(cls.org == org) - @classmethod - def find_by_email(cls, email): - return cls.select().where(cls.email == email) - def __unicode__(self): return u'%s (%s)' % (self.name, self.email) @@ -435,7 +431,7 @@ def to_dict(self, all=False, with_permissions_for=None): if with_permissions_for is not None: d['view_only'] = db.session.query(DataSourceGroup.view_only).filter( DataSourceGroup.group == with_permissions_for, - DataSourceGroup.data_source == self).get() + DataSourceGroup.data_source == self).one() return d @@ -1470,7 +1466,7 @@ class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): @classmethod def all(cls, org): - return cls.select().where(cls.org==org) + return cls.query.filter(cls.org == org) def to_dict(self): d = { diff --git a/redash/serializers.py b/redash/serializers.py index f238a2b1ab..769a770eed 100644 --- a/redash/serializers.py +++ b/redash/serializers.py @@ -42,10 +42,10 @@ def public_widget(widget): def public_dashboard(dashboard): dashboard_dict = project(dashboard.to_dict(), ('name', 'layout', 'dashboard_filters_enabled', 'updated_at', 'created_at')) - widget_list = models.Widget.select(models.Widget, models.Visualization, models.Query) \ - .where(models.Widget.dashboard == dashboard.id) \ - .join(models.Visualization, join_type=models.peewee.JOIN_LEFT_OUTER) \ - .join(models.Query, join_type=models.peewee.JOIN_LEFT_OUTER) + widget_list = (models.Widget.query + .filter(models.Widget.dashboard_id == dashboard.id) + .outerjoin(models.Visualization) + .outerjoin(models.Query)) widgets = {w.id: public_widget(w) for w in widget_list} widgets_layout = [] diff --git a/tests/handlers/test_embed.py b/tests/handlers/test_embed.py index 6ae8830945..18f119d786 100644 --- a/tests/handlers/test_embed.py +++ b/tests/handlers/test_embed.py @@ -71,3 +71,33 @@ def test_inactive_token(self): # add this test. # def test_token_doesnt_belong_to_dashboard(self): # pass + +class TestAPIPublicDashboard(BaseTestCase): + def test_success(self): + dashboard = self.factory.create_dashboard() + api_key = self.factory.create_api_key(object=dashboard) + + res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), user=False, is_json=False) + self.assertEqual(res.status_code, 200) + + def test_works_for_logged_in_user(self): + dashboard = self.factory.create_dashboard() + api_key = self.factory.create_api_key(object=dashboard) + + res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), is_json=False) + self.assertEqual(res.status_code, 200) + + def test_bad_token(self): + res = self.make_request('get', '/api/dashboards/public/bad-token', user=False, is_json=False) + self.assertEqual(res.status_code, 404) + + def test_inactive_token(self): + dashboard = self.factory.create_dashboard() + api_key = self.factory.create_api_key(object=dashboard, active=False) + res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), user=False, is_json=False) + self.assertEqual(res.status_code, 404) + + # Not relevant for now, as tokens in api_keys table are only created for dashboards. Once this changes, we should + # add this test. + # def test_token_doesnt_belong_to_dashboard(self): + # pass diff --git a/tests/handlers/test_groups.py b/tests/handlers/test_groups.py index 5978869bc4..5ddd44e366 100644 --- a/tests/handlers/test_groups.py +++ b/tests/handlers/test_groups.py @@ -1,15 +1,61 @@ +from funcy import project + from tests import BaseTestCase -from redash.models import Group, DataSource, NoResultFound +from redash.models import Group, DataSource, NoResultFound, db class TestGroupDataSourceListResource(BaseTestCase): def test_returns_only_groups_for_current_org(self): group = self.factory.create_group(org=self.factory.create_org()) data_source = self.factory.create_data_source(group=group) - + db.session.flush() response = self.make_request('get', '/api/groups/{}/data_sources'.format(group.id), user=self.factory.create_admin()) self.assertEqual(response.status_code, 404) + def test_list(self): + group = self.factory.create_group() + ds = self.factory.create_data_source(group=group) + db.session.flush() + response = self.make_request( + 'get', '/api/groups/{}/data_sources'.format(group.id), + user=self.factory.create_admin()) + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response.json), 1) + self.assertEqual(response.json[0]['id'], ds.id) + + +class TestGroupResourceList(BaseTestCase): + + def test_list_admin(self): + self.factory.create_group(org=self.factory.create_org()) + response = self.make_request('get', '/api/groups', + user=self.factory.create_admin()) + g_keys = ['type', 'id', 'name', 'permissions'] + + def filtergroups(gs): + return [project(g, g_keys) for g in gs] + self.assertEqual(filtergroups(response.json), + filtergroups(g.to_dict() for g in [ + self.factory.admin_group, + self.factory.default_group])) + + def test_list(self): + group1 = self.factory.create_group(org=self.factory.create_org(), + permissions=['view_dashboard']) + db.session.flush() + u = self.factory.create_user(group_ids=[self.factory.default_group.id, + group1.id]) + db.session.flush() + response = self.make_request('get', '/api/groups', + user=u) + g_keys = ['type', 'id', 'name', 'permissions'] + + def filtergroups(gs): + return [project(g, g_keys) for g in gs] + self.assertEqual(filtergroups(response.json), + filtergroups(g.to_dict() for g in [ + self.factory.default_group, + group1])) class TestGroupResourcePost(BaseTestCase): def test_doesnt_change_builtin_groups(self): diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 2fa1381c4d..3b7752bd38 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from funcy import project from flask import url_for from flask_login import current_user @@ -302,3 +302,93 @@ def test_logout_when_loggedin(self): rv = c.get('/default/logout') self.assertEquals(rv.status_code, 302) self.assertFalse(current_user.is_authenticated) + + +class TestQuerySnippet(BaseTestCase): + def test_create(self): + res = self.make_request( + 'post', + '/api/query_snippets', + data={'trigger': 'x', 'description': 'y', 'snippet': 'z'}, + user=self.factory.user) + self.assertEqual( + project(res.json, ['id', 'trigger', 'description', 'snippet']), { + 'id': 1, + 'trigger': 'x', + 'description': 'y', + 'snippet': 'z', + }) + qs = models.QuerySnippet.query.one() + self.assertEqual(qs.trigger, 'x') + self.assertEqual(qs.description, 'y') + self.assertEqual(qs.snippet, 'z') + + def test_edit(self): + qs = models.QuerySnippet( + trigger='a', + description='b', + snippet='c', + user=self.factory.user, + org=self.factory.org + ) + models.db.session.add(qs) + models.db.session.commit() + res = self.make_request( + 'post', + '/api/query_snippets/1', + data={'trigger': 'x', 'description': 'y', 'snippet': 'z'}, + user=self.factory.user) + self.assertEqual( + project(res.json, ['id', 'trigger', 'description', 'snippet']), { + 'id': 1, + 'trigger': 'x', + 'description': 'y', + 'snippet': 'z', + }) + self.assertEqual(qs.trigger, 'x') + self.assertEqual(qs.description, 'y') + self.assertEqual(qs.snippet, 'z') + + def test_list(self): + qs = models.QuerySnippet( + trigger='x', + description='y', + snippet='z', + user=self.factory.user, + org=self.factory.org + ) + models.db.session.add(qs) + models.db.session.commit() + res = self.make_request( + 'get', + '/api/query_snippets', + user=self.factory.user) + self.assertEqual(res.status_code, 200) + data = res.json + self.assertEqual(len(data), 1) + self.assertEqual( + project(data[0], ['id', 'trigger', 'description', 'snippet']), { + 'id': 1, + 'trigger': 'x', + 'description': 'y', + 'snippet': 'z', + }) + self.assertEqual(qs.trigger, 'x') + self.assertEqual(qs.description, 'y') + self.assertEqual(qs.snippet, 'z') + + def test_delete(self): + qs = models.QuerySnippet( + trigger='a', + description='b', + snippet='c', + user=self.factory.user, + org=self.factory.org + ) + models.db.session.add(qs) + models.db.session.commit() + self.make_request( + 'delete', + '/api/query_snippets/1', + user=self.factory.user) + self.assertEqual(models.QuerySnippet.query.count(), 0) From 4edc3e3f217acdca6264b844c760e26721629871 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 03:09:41 -0600 Subject: [PATCH 47/80] gratuitous admin query fix --- redash/handlers/admin.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/redash/handlers/admin.py b/redash/handlers/admin.py index 1a24a75b43..325652b885 100644 --- a/redash/handlers/admin.py +++ b/redash/handlers/admin.py @@ -15,14 +15,17 @@ def outdated_queries(): manager_status = redis_connection.hgetall('redash:status') query_ids = json.loads(manager_status.get('query_ids', '[]')) if query_ids: - outdated_queries = models.Query.select(models.Query, models.QueryResult.retrieved_at, models.QueryResult.runtime) \ - .join(models.QueryResult, join_type=models.peewee.JOIN_LEFT_OUTER) \ - .where(models.Query.id << query_ids) \ - .order_by(models.Query.created_at.desc()) + outdated_queries = (models.db.session.query(models.Query) + .outerjoin(models.QueryResult) + .filter(models.Query.id.in_(query_ids)) + .order_by(models.Query.created_at.desc())) else: outdated_queries = [] - return json_response(dict(queries=[q.to_dict(with_stats=True, with_last_modified_by=False) for q in outdated_queries], updated_at=manager_status['last_refresh_at'])) + return json_response( + dict(queries=[q.to_dict(with_stats=True, with_last_modified_by=False) + for q in outdated_queries], + updated_at=manager_status['last_refresh_at'])) @routes.route('/api/admin/queries/tasks', methods=['GET']) From dc6bc071f1c8d1e94441a358b6675856a3a10d80 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 03:09:55 -0600 Subject: [PATCH 48/80] work around flask proxy object --- redash/handlers/authentication.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index a4dfb12a94..57d4fc1f1f 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -29,7 +29,8 @@ def get_google_auth_url(next_path): def render_token_login_page(template, org_slug, token): try: user_id = validate_token(token) - user = models.User.get_by_id_and_org(user_id, current_org) + org = current_org._get_current_object() + user = models.User.get_by_id_and_org(user_id, org) except NoResultFound: logger.exception("Bad user id in token. Token= , User id= %s, Org=%s", user_id, token, org_slug) return render_template("error.html", error_message="Invalid invite link. Please ask for a new one."), 400 @@ -79,7 +80,8 @@ def forgot_password(org_slug=None): submitted = True email = request.form['email'] try: - user = models.User.get_by_email_and_org(email, current_org) + org = current_org._get_current_object() + user = models.User.get_by_email_and_org(email, org) send_password_reset_email(user) except NoResultFound: logging.error("No user found for forgot password: %s", email) @@ -105,7 +107,8 @@ def login(org_slug=None): if request.method == 'POST': try: - user = models.User.get_by_email_and_org(request.form['email'], current_org) + org = current_org._get_current_object() + user = models.User.get_by_email_and_org(request.form['email'], org) if user and user.verify_password(request.form['password']): remember = ('remember' in request.form) login_user(user, remember=remember) From 74e1b3119fea1068cd20f186ac3a1e22d43d07f4 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 11:46:30 +0200 Subject: [PATCH 49/80] Remove dead code --- redash/models.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/redash/models.py b/redash/models.py index 3989790316..0865c61730 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1199,13 +1199,6 @@ def recent(cls, org, group_ids, user_id, for_user=False, limit=20): def get_by_slug_and_org(cls, slug, org): return cls.query.filter(cls.slug == slug, cls.org==org).one() - def tracked_save(self, changing_user, old_object=None, *args, **kwargs): - self.version += 1 - self.save(*args, **kwargs) - # save Change record - new_change = Change.save_change(user=changing_user, old_object=old_object, new_object=self) - return new_change - def __unicode__(self): return u"%s=%s" % (self.id, self.name) From 6b0e45441c620babefc5d5c46ee086a1b2ef0ce3 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 12:26:48 +0200 Subject: [PATCH 50/80] Update QuerySnippets code to SQLA --- redash/handlers/query_snippets.py | 2 + tests/factories.py | 13 ++++++ tests/handlers/test_query_snippets.py | 57 +++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 tests/handlers/test_query_snippets.py diff --git a/redash/handlers/query_snippets.py b/redash/handlers/query_snippets.py index 6a2f45465a..c933b4bf92 100644 --- a/redash/handlers/query_snippets.py +++ b/redash/handlers/query_snippets.py @@ -55,7 +55,9 @@ def post(self): user=self.current_user, org=self.current_org ) + models.db.session.add(snippet) + self.record_event({ 'action': 'create', 'object_id': snippet.id, diff --git a/tests/factories.py b/tests/factories.py index 55b88c6210..2f97d8843f 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -136,6 +136,11 @@ def __call__(self): destination=destination_factory.create, alert=alert_factory.create) +query_snippet_factory = ModelFactory(redash.models.QuerySnippet, + trigger=Sequence('trigger {}'), + description='description', + snippet='snippet') + class Factory(object): def __init__(self): @@ -342,3 +347,11 @@ def create_destination(self, **kwargs): } args.update(kwargs) return destination_factory.create(**args) + + def create_query_snippet(self, **kwargs): + args = { + 'user': self.user, + 'org': self.org + } + args.update(kwargs) + return query_snippet_factory.create(**args) diff --git a/tests/handlers/test_query_snippets.py b/tests/handlers/test_query_snippets.py new file mode 100644 index 0000000000..0a6109c539 --- /dev/null +++ b/tests/handlers/test_query_snippets.py @@ -0,0 +1,57 @@ +from tests import BaseTestCase +from redash.models import QuerySnippet + + +class TestQuerySnippetResource(BaseTestCase): + def test_get_snippet(self): + snippet = self.factory.create_query_snippet() + + rv = self.make_request('get', '/api/query_snippets/{}'.format(snippet.id)) + + for field in ('snippet', 'description', 'trigger'): + self.assertEqual(rv.json[field], getattr(snippet, field)) + + def test_update_snippet(self): + snippet = self.factory.create_query_snippet() + + data = { + 'snippet': 'updated', + 'trigger': 'updated trigger', + 'description': 'updated description' + } + + rv = self.make_request('post', '/api/query_snippets/{}'.format(snippet.id), data=data) + + for field in ('snippet', 'description', 'trigger'): + self.assertEqual(rv.json[field], data[field]) + + def test_delete_snippet(self): + snippet = self.factory.create_query_snippet() + rv = self.make_request('delete', '/api/query_snippets/{}'.format(snippet.id)) + + self.assertIsNone(QuerySnippet.query.get(snippet.id)) + + +class TestQuerySnippetListResource(BaseTestCase): + def test_create_snippet(self): + data = { + 'snippet': 'updated', + 'trigger': 'updated trigger', + 'description': 'updated description' + } + + rv = self.make_request('post', '/api/query_snippets', data=data) + self.assertEqual(rv.status_code, 200) + + def test_list_all_snippets(self): + snippet1 = self.factory.create_query_snippet() + snippet2 = self.factory.create_query_snippet() + snippet_diff_org = self.factory.create_query_snippet(org=self.factory.create_org()) + + rv = self.make_request('get', '/api/query_snippets') + ids = [s['id'] for s in rv.json] + + self.assertIn(snippet1.id, ids) + self.assertIn(snippet2.id, ids) + self.assertNotIn(snippet_diff_org.id, ids) + From 8c2b310419bf15e29d55414694272767614f06aa Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 13:54:12 +0200 Subject: [PATCH 51/80] Add tests to alerts celery task --- redash/handlers/alerts.py | 7 ++---- redash/models.py | 2 +- redash/tasks/alerts.py | 44 ++++++++++++++++++++++------------- tests/handlers/test_alerts.py | 6 +++++ tests/tasks/test_alerts.py | 32 +++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 22 deletions(-) create mode 100644 tests/tasks/test_alerts.py diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index f704aa388e..82abe07b88 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -6,7 +6,6 @@ from redash import models from redash.permissions import require_access, require_admin_or_owner, view_only, require_permission from redash.handlers.base import BaseResource, require_fields, get_object_or_404 -from sqlalchemy.exc import DataError class AlertResource(BaseResource): @@ -21,7 +20,7 @@ def post(self, alert_id): alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) require_admin_or_owner(alert.user.id) - alert.update_instance(**params) + self.update_model(alert, params) self.record_event({ 'action': 'edit', @@ -30,9 +29,7 @@ def post(self, alert_id): 'object_type': 'alert' }) - d = alert.to_dict() - models.db.session.commit() - return d + return alert.to_dict() def delete(self, alert_id): alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, diff --git a/redash/models.py b/redash/models.py index 0865c61730..2909e3bd88 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1324,6 +1324,7 @@ def record(cls, event): db.session.add(event) return event + class ApiKey(TimestampMixin, GFKBase, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1353,7 +1354,6 @@ def create_for_object(cls, object, user): class NotificationDestination(BelongsToOrgMixin, db.Model): - id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship(Organization, backref="notification_destinations") diff --git a/redash/tasks/alerts.py b/redash/tasks/alerts.py index 45c01f9f39..a431f2bb89 100644 --- a/redash/tasks/alerts.py +++ b/redash/tasks/alerts.py @@ -1,4 +1,5 @@ from celery.utils.log import get_task_logger +from flask import current_app import datetime from redash.worker import celery from redash import utils @@ -16,31 +17,42 @@ def base_url(org): return settings.HOST +def notify_subscriptions(alert, new_state): + host = base_url(alert.query_rel.org) + for subscription in alert.subscriptions: + try: + subscription.notify(alert, alert.query, subscription.user, new_state, current_app, host) + except Exception as e: + logger.exception("Error with processing destination") + + +def should_notify(alert, new_state): + passed_rearm_threshold = False + if alert.rearm and alert.last_triggered_at: + passed_rearm_threshold = alert.last_triggered_at + datetime.timedelta(seconds=alert.rearm) < utils.utcnow() + + return new_state != alert.state or (alert.state == models.Alert.TRIGGERED_STATE and passed_rearm_threshold) + + @celery.task(name="redash.tasks.check_alerts_for_query", base=BaseTask) def check_alerts_for_query(query_id): - from redash.wsgi import app - logger.debug("Checking query %d for alerts", query_id) - query = models.Query.get_by_id(query_id) + + query = models.Query.query.get(query_id) + for alert in query.alerts: - alert.query = query new_state = alert.evaluate() - passed_rearm_threshold = False - if alert.rearm and alert.last_triggered_at: - passed_rearm_threshold = alert.last_triggered_at + datetime.timedelta(seconds=alert.rearm) < utils.utcnow() - if new_state != alert.state or (alert.state == models.Alert.TRIGGERED_STATE and passed_rearm_threshold ): + + if should_notify(alert, new_state): logger.info("Alert %d new state: %s", alert.id, new_state) old_state = alert.state - alert.update_instance(state=new_state, last_triggered_at=utils.utcnow()) + + alert.state = new_state + alert.last_triggered_at = utils.utcnow() + models.db.session.commit() if old_state == models.Alert.UNKNOWN_STATE and new_state == models.Alert.OK_STATE: logger.debug("Skipping notification (previous state was unknown and now it's ok).") continue - host = base_url(alert.query.org) - for subscription in alert.subscriptions: - try: - subscription.notify(alert, query, subscription.user, new_state, app, host) - except Exception as e: - logger.exception("Error with processing destination") - + notify_subscriptions(alert, new_state) diff --git a/tests/handlers/test_alerts.py b/tests/handlers/test_alerts.py index 0bee3e3d25..2f8f679a52 100644 --- a/tests/handlers/test_alerts.py +++ b/tests/handlers/test_alerts.py @@ -27,6 +27,12 @@ def test_returns_404_if_admin_from_another_org(self): self.assertEqual(rv.status_code, 404) +class TestAlertResourcePost(BaseTestCase): + def test_updates_alert(self): + alert = self.factory.create_alert() + rv = self.make_request('post', '/api/alerts/{}'.format(alert.id), data={"name": "Testing"}) + + class TestAlertResourceDelete(BaseTestCase): def test_removes_alert_and_subscriptions(self): subscription = self.factory.create_alert_subscription() diff --git a/tests/tasks/test_alerts.py b/tests/tasks/test_alerts.py new file mode 100644 index 0000000000..9a5049cdd8 --- /dev/null +++ b/tests/tasks/test_alerts.py @@ -0,0 +1,32 @@ +from tests import BaseTestCase +from mock import MagicMock + +import redash.tasks.alerts +from redash.tasks.alerts import check_alerts_for_query, notify_subscriptions, should_notify +from redash.models import Alert + + +class TestCheckAlertsForQuery(BaseTestCase): + def test_notifies_subscribers_when_should(self): + redash.tasks.alerts.notify_subscriptions = MagicMock() + Alert.evaluate = MagicMock(return_value=Alert.TRIGGERED_STATE) + + alert = self.factory.create_alert() + check_alerts_for_query(alert.query_id) + + self.assertTrue(redash.tasks.alerts.notify_subscriptions.called) + + def test_doesnt_notify_when_nothing_changed(self): + redash.tasks.alerts.notify_subscriptions = MagicMock() + Alert.evaluate = MagicMock(return_value=Alert.OK_STATE) + + alert = self.factory.create_alert() + check_alerts_for_query(alert.query_id) + + self.assertFalse(redash.tasks.alerts.notify_subscriptions.called) + + +class TestNotifySubscriptions(BaseTestCase): + def test_calls_notify_for_subscribers(self): + subscription = self.factory.create_alert_subscription() + notify_subscriptions(subscription.alert, Alert.OK_STATE) From ecbed0087e2b61fa389a4f4c00c9eac6a8566845 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 14:07:42 +0200 Subject: [PATCH 52/80] Update create_and_login_user not to call save --- redash/authentication/google_oauth.py | 2 +- tests/test_authentication.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/redash/authentication/google_oauth.py b/redash/authentication/google_oauth.py index c02f467f47..3a215f5a9a 100644 --- a/redash/authentication/google_oauth.py +++ b/redash/authentication/google_oauth.py @@ -64,7 +64,7 @@ def create_and_login_user(org, name, email): if user_object.name != name: logger.debug("Updating user name (%r -> %r)", user_object.name, name) user_object.name = name - user_object.save() + models.db.session.commit() except NoResultFound: logger.debug("Creating user object (%r)", name) user_object = models.User(org=org, name=name, email=email, group_ids=[org.default_group.id]) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 41bff2fa57..c06d6e78e3 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -132,6 +132,13 @@ def test_creates_vaild_new_user(self): user = models.User.query.filter(models.User.email == email).one() self.assertEqual(user.email, email) + def test_updates_user_name(self): + user = self.factory.create_user(email='test@example.com') + + with patch('redash.authentication.google_oauth.login_user') as login_user_mock: + create_and_login_user(self.factory.org, "New Name", user.email) + login_user_mock.assert_called_once_with(user, remember=True) + class TestVerifyProfile(BaseTestCase): def test_no_domain_allowed_for_org(self): From 463da02be16bee583917c11b0633b11b56184080 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 14:17:57 +0200 Subject: [PATCH 53/80] remove group_hack method (unused) --- redash/models.py | 5 ----- tests/factories.py | 11 ----------- 2 files changed, 16 deletions(-) diff --git a/redash/models.py b/redash/models.py index 2909e3bd88..404a13bf42 100644 --- a/redash/models.py +++ b/redash/models.py @@ -294,11 +294,6 @@ def find_by_name(cls, org, group_names): def __unicode__(self): return unicode(self.id) -def create_group_hack(*a, **kw): - g = Group(*a, **kw) - db.session.add(g) - db.commit() - return g.id class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): id = Column(db.Integer, primary_key=True) diff --git a/tests/factories.py b/tests/factories.py index 2f97d8843f..1cc878facb 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -215,17 +215,6 @@ def create_group(self, **kwargs): g = redash.models.Group(**args) return g - def create_group_hack(self, **kwargs): - args = { - 'name': 'Group', - 'org': self.org - } - - args.update(kwargs) - - g_id = redash.models.create_group_hack(**args) - return g_id - def create_alert(self, **kwargs): args = { 'user': self.user, From 0c974bd48bdff0185be499537fab647cc3e4ca16 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 14:24:19 +0200 Subject: [PATCH 54/80] Update User.find_by_email to SQLA --- redash/models.py | 4 ++++ tests/models/test_users.py | 27 +++++++++++++++++++++++++++ tests/test_models.py | 13 ------------- 3 files changed, 31 insertions(+), 13 deletions(-) create mode 100644 tests/models/test_users.py diff --git a/redash/models.py b/redash/models.py index 404a13bf42..9b8cd2e780 100644 --- a/redash/models.py +++ b/redash/models.py @@ -358,6 +358,10 @@ def get_by_api_key_and_org(cls, api_key, org): def all(cls, org): return cls.query.filter(cls.org == org) + @classmethod + def find_by_email(cls, email): + return cls.query.filter(cls.email == email) + def __unicode__(self): return u'%s (%s)' % (self.name, self.email) diff --git a/tests/models/test_users.py b/tests/models/test_users.py new file mode 100644 index 0000000000..40ac11c813 --- /dev/null +++ b/tests/models/test_users.py @@ -0,0 +1,27 @@ +from tests import BaseTestCase + +from redash.models import User + +class TestUserUpdateGroupAssignments(BaseTestCase): + def test_default_group_always_added(self): + user = self.factory.create_user() + + user.update_group_assignments(["g_unknown"]) + self.assertItemsEqual([user.org.default_group.id], user.group_ids) + + def test_update_group_assignments(self): + user = self.factory.user + new_group = self.factory.create_group(name="g1") + + user.update_group_assignments(["g1"]) + self.assertItemsEqual([user.org.default_group.id, new_group.id], user.group_ids) + + +class TestUserFindByEmail(BaseTestCase): + def test_finds_users(self): + user = self.factory.create_user(email='test@example.com') + user2 = self.factory.create_user(email='test@example.com', org=self.factory.create_org()) + + users = User.find_by_email(user.email) + self.assertIn(user, users) + self.assertIn(user2, users) diff --git a/tests/test_models.py b/tests/test_models.py index 846773523b..e2744024fa 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -438,19 +438,6 @@ def test_returns_only_queries_in_given_groups(self): self.assertIn(q2, list(models.Query.all_queries([group1, group2]))) -class TestUser(BaseTestCase): - def test_default_group_always_added(self): - user = self.factory.create_user() - - user.update_group_assignments(["g_unknown"]) - self.assertItemsEqual([user.org.default_group.id], user.group_ids) - - def test_update_group_assignments(self): - user = self.factory.user - new_group = models.Group(id=999, name="g1", org=user.org) - - user.update_group_assignments(["g1"]) - self.assertItemsEqual([user.org.default_group.id, new_group.id], user.group_ids) class TestGroup(BaseTestCase): From 045e880f25c4c9685957d337b30828a06ac9eead Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 14:30:00 +0200 Subject: [PATCH 55/80] Add dedicated delete method to widgets instead of using an event --- redash/handlers/widgets.py | 2 +- redash/models.py | 18 +++++++++--------- tests/test_models.py | 10 +++------- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/redash/handlers/widgets.py b/redash/handlers/widgets.py index 9aab8023c1..47ce325705 100644 --- a/redash/handlers/widgets.py +++ b/redash/handlers/widgets.py @@ -70,6 +70,6 @@ def delete(self, widget_id): widget = models.Widget.get_by_id_and_org(widget_id, self.current_org) require_object_modify_permission(widget.dashboard, self.current_user) - models.db.session.delete(widget) + widget.delete() return {'layout': widget.dashboard.layout, 'version': widget.dashboard.version} diff --git a/redash/models.py b/redash/models.py index 9b8cd2e780..68f9a9ebbb 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1271,6 +1271,15 @@ def to_dict(self): return d + def delete(self): + layout = json.loads(self.dashboard.layout) + layout = map(lambda row: filter(lambda w: w != self.id, row), layout) + layout = filter(lambda row: len(row) > 0, layout) + self.dashboard.layout = json.dumps(layout) + + db.session.add(self.dashboard) + db.session.delete(self) + def __unicode__(self): return u"%s" % self.id @@ -1278,15 +1287,6 @@ def __unicode__(self): def get_by_id_and_org(cls, widget_id, org): return db.session.query(cls).join(Dashboard).filter(cls.id == widget_id, Dashboard.org== org).one() -#XXX produces SQLA warning, replace with association table -@listens_for(Widget, 'before_delete') -def widget_delete(mapper, connection, self): - layout = json.loads(self.dashboard.layout) - layout = map(lambda row: filter(lambda w: w != self.id, row), layout) - layout = filter(lambda row: len(row) > 0, layout) - self.dashboard.layout = json.dumps(layout) - db.session.add(self.dashboard) - class Event(db.Model): id = Column(db.Integer, primary_key=True) diff --git a/tests/test_models.py b/tests/test_models.py index e2744024fa..a6ed7f544d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -438,8 +438,6 @@ def test_returns_only_queries_in_given_groups(self): self.assertIn(q2, list(models.Query.all_queries([group1, group2]))) - - class TestGroup(BaseTestCase): def test_returns_groups_with_specified_names(self): org1 = self.factory.create_org() @@ -565,8 +563,7 @@ def test_delete_removes_from_layout(self): widget2 = self.factory.create_widget(dashboard=widget.dashboard) db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) - db.session.delete(widget) - db.session.flush() + widget.delete() self.assertEquals(json.dumps([[widget2.id]]), widget.dashboard.layout) def test_delete_removes_empty_rows(self): @@ -575,9 +572,8 @@ def test_delete_removes_empty_rows(self): db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) db.session.flush() - db.session.delete(widget) - db.session.delete(widget2) - db.session.flush() + widget.delete() + widget2.delete() self.assertEquals("[]", widget.dashboard.layout) From fb7562645824729a7e82c21a653f25892745d4d1 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Tue, 6 Dec 2016 15:30:35 +0200 Subject: [PATCH 56/80] Create default visualization using a method instead of signal --- redash/handlers/queries.py | 2 +- redash/models.py | 24 ++++++++++++------------ tests/test_handlers.py | 3 +-- tests/test_models.py | 5 ----- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index f65b331e5c..a6e70ea1d1 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -67,7 +67,7 @@ def post(self): query_def['data_source'] = data_source query_def['org'] = self.current_org query_def['is_draft'] = True - query = models.Query(**query_def) + query = models.Query.create(**query_def) models.db.session.add(query) models.db.session.commit() diff --git a/redash/models.py b/redash/models.py index 68f9a9ebbb..3a1e2383c7 100644 --- a/redash/models.py +++ b/redash/models.py @@ -725,6 +725,16 @@ def archive(self, user=None): if user: self.record_changes(user) + @classmethod + def create(cls, **kwargs): + query = cls(**kwargs) + db.session.add(Visualization(query_rel=query, + name="Table", + description='', + type="TABLE", + options="{}")) + return query + @classmethod def all_queries(cls, groups, drafts=False): q = (cls.query.join(User, Query.user_id == User.id) @@ -810,8 +820,8 @@ def fork(self, user): forked_list = ['org', 'data_source', 'latest_query_data', 'description', 'query_text', 'query_hash'] kwargs = {a: getattr(self, a) for a in forked_list} - forked_query = Query(name='Copy of (#{}) {}'.format(self.id, self.name), - user=user, **kwargs) + forked_query = Query.create(name=u'Copy of (#{}) {}'.format(self.id, self.name), + user=user, **kwargs) for v in self.visualizations: if v.type == 'TABLE': @@ -867,16 +877,6 @@ def query_last_modified_by(target, val, oldval, initiator): target.last_modified_by_id = val -# Create default (table) visualization: -@listens_for(SignallingSession, 'before_flush') -def create_defaults(session, ctx, *a): - for obj in session.new: - if isinstance(obj, Query): - session.add(Visualization(query_rel=obj, name="Table", - description='', - type="TABLE", options="{}")) - - class AccessPermission(GFKBase, db.Model): id = Column(db.Integer, primary_key=True) # 'object' defined in GFKBase diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 3b7752bd38..229503e90b 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -89,8 +89,7 @@ def test_delete_visualization(self): rv = self.make_request('delete', '/api/visualizations/{}'.format(visualization.id)) self.assertEquals(rv.status_code, 200) - # =1 because each query has a default table visualization. - self.assertEquals(models.db.session.query(models.Visualization).count(), 1) + self.assertEquals(models.db.session.query(models.Visualization).count(), 0) def test_update_visualization(self): visualization = self.factory.create_visualization() diff --git a/tests/test_models.py b/tests/test_models.py index a6ed7f544d..12940db090 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -104,11 +104,6 @@ def test_returns_each_query_only_once(self): self.assertEqual(1, len(queries)) - def test_save_creates_default_visualization(self): - q = self.factory.create_query() - db.session.flush() - self.assertEquals(len(q.visualizations), 1) - def test_save_updates_updated_at_field(self): # This should be a test of ModelTimestampsMixin, but it's easier to test in context of existing model... :-\ one_day_ago = utcnow().date() - datetime.timedelta(days=1) From 51117e8e5bdbf5ed59f7f1f740b5689a9082cddb Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 12:42:01 +0200 Subject: [PATCH 57/80] Update dependencies to more latest versions --- requirements.txt | 19 ++++++++++--------- requirements_dev.txt | 6 +++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index d509964b80..c700b5d603 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,22 @@ -httplib2==0.9.2 Flask==0.11.1 -Flask-Admin==1.1.0 +Werkzeug==0.11.11 +Jinja2==2.8 +itsdangerous==0.24 +click==6.6 +MarkupSafe==0.23 + +httplib2==0.9.2 +Flask-Admin==1.4.2 Flask-RESTful==0.3.5 -Flask-Login==0.3.2 +Flask-Login==0.4.0 Flask-OAuthLib==0.9.2 Flask-SQLAlchemy==2.1 flask-mail==0.9.1 flask-sslify==0.1.5 +Flask-Limiter==0.9.3 passlib==1.6.2 -Jinja2==2.8 -MarkupSafe==0.23 -Werkzeug==0.11.3 aniso8601==1.1.0 blinker==1.3 -itsdangerous==0.24 psycopg2==2.5.2 python-dateutil==2.4.2 pytz==2016.7 @@ -28,7 +31,6 @@ statsd==2.1.2 gunicorn==19.4.5 celery==3.1.23 jsonschema==2.4.0 -click==6.6 RestrictedPython==3.6.0 pysaml2==2.4.0 pycrypto==2.6.1 @@ -39,5 +41,4 @@ xlsxwriter==0.9.3 pystache==0.5.4 parsedatetime==2.1 cryptography==1.4 -Flask-Limiter==0.9.3 simplejson==3.10.0 diff --git a/requirements_dev.txt b/requirements_dev.txt index d9c92539de..4f480bbdc3 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,3 @@ -nose==1.3.0 -coverage==3.7.1 -mock==1.0.1 +nose==1.3.7 +coverage==4.0.3 +mock==2.0.0 From 73121890b3391c07b24a8240c73eddb42e0f7a0f Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 12:42:46 +0200 Subject: [PATCH 58/80] Remove usage of the deprecated flask.ext package --- redash/handlers/destinations.py | 2 +- redash/models.py | 2 -- redash/tasks/general.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/redash/handlers/destinations.py b/redash/handlers/destinations.py index 97c2eebaee..6213626b61 100644 --- a/redash/handlers/destinations.py +++ b/redash/handlers/destinations.py @@ -1,5 +1,5 @@ from flask import make_response, request -from flask.ext.restful import abort +from flask_restful import abort from redash import models from redash.permissions import require_admin diff --git a/redash/models.py b/redash/models.py index 3a1e2383c7..17b9c979a6 100644 --- a/redash/models.py +++ b/redash/models.py @@ -4,11 +4,9 @@ import itertools import json import logging -import time from funcy import project from flask_sqlalchemy import SQLAlchemy -from flask.ext.sqlalchemy import SignallingSession from flask_login import UserMixin, AnonymousUserMixin from sqlalchemy.dialects import postgresql from sqlalchemy.event import listens_for diff --git a/redash/tasks/general.py b/redash/tasks/general.py index 3aa6e6d2c1..a2af50e198 100644 --- a/redash/tasks/general.py +++ b/redash/tasks/general.py @@ -1,6 +1,6 @@ import requests from celery.utils.log import get_task_logger -from flask.ext.mail import Message +from flask_mail import Message from redash.worker import celery from redash.version_check import run_version_check from redash import models, mail, settings From 80a7d377fe22d811c48ab5d483bf8930a4aebc9f Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 12:43:17 +0200 Subject: [PATCH 59/80] Fix: set the column name of QueryResult.query_text --- redash/admin.py | 2 +- redash/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/redash/admin.py b/redash/admin.py index 731f0ffe35..d58501abe3 100644 --- a/redash/admin.py +++ b/redash/admin.py @@ -61,7 +61,7 @@ class DashboardModelView(BaseModelView): def init_admin(app): - admin = Admin(app, name='re:dash admin', template_mode='bootstrap3') + admin = Admin(app, name='Redash Admin', template_mode='bootstrap3') admin.add_view(QueryModelView(models.Query, models.db.session)) admin.add_view(QueryResultModelView(models.QueryResult, models.db.session)) diff --git a/redash/models.py b/redash/models.py index 17b9c979a6..57161318a2 100644 --- a/redash/models.py +++ b/redash/models.py @@ -535,7 +535,7 @@ class QueryResult(db.Model, BelongsToOrgMixin): data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) data_source = db.relationship(DataSource) query_hash = Column(db.String(32), index=True) - query_text = Column(db.Text) + query_text = Column('query', db.Text) data = Column(db.Text) runtime = Column(postgresql.DOUBLE_PRECISION) retrieved_at = Column(db.DateTime(True)) From bcce9cf251eed5d94626636e4e0cd2329c45b4fe Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 14:00:03 +0200 Subject: [PATCH 60/80] Update circle.yml to skip docker builds for now --- circle.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/circle.yml b/circle.yml index e1a9eb97d5..ff250851ef 100644 --- a/circle.yml +++ b/circle.yml @@ -26,11 +26,11 @@ deployment: - make pack # Skipping uploads for now, until master is stable. # - make upload - - echo "client/app" >> .dockerignore - - docker pull redash/redash:latest - - docker build -t redash/redash:$(./manage.py version | sed -e "s/\+/./") . - - docker login -e $DOCKER_EMAIL -u $DOCKER_USER -p $DOCKER_PASS - - docker push redash/redash:$(./manage.py version | sed -e "s/\+/./") + #- echo "client/app" >> .dockerignore + #- docker pull redash/redash:latest + #- docker build -t redash/redash:$(./manage.py version | sed -e "s/\+/./") . + #- docker login -e $DOCKER_EMAIL -u $DOCKER_USER -p $DOCKER_PASS + #- docker push redash/redash:$(./manage.py version | sed -e "s/\+/./") notify: webhooks: - url: https://webhooks.gitter.im/e/895d09c3165a0913ac2f From b9024b18c13cd0fcae7cd6d212d82afd7b1e26f0 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 14:00:37 +0200 Subject: [PATCH 61/80] Remove unused code --- redash/models.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/redash/models.py b/redash/models.py index 57161318a2..00b641f2e0 100644 --- a/redash/models.py +++ b/redash/models.py @@ -974,10 +974,6 @@ def to_dict(self, full=True): return d - @classmethod - def log_change(cls, changed_by, obj): - return cls.create(object=obj, object_version=obj.version, user=changed_by, change=obj.changes) - @classmethod def last_change(cls, obj): return db.session.query(cls).filter( From a84c3e25f5a044b5bb3ef581f0a218d678e2af82 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 14:58:33 +0200 Subject: [PATCH 62/80] Move CLI logic into redash.cli and uses manager for tests. Otherwise the Flask wasn't created and tests were failing. --- manage.py | 69 +----------------------- redash/cli/__init__.py | 63 ++++++++++++++++++++++ tests/test_cli.py | 116 +++++++++++++++++++---------------------- 3 files changed, 119 insertions(+), 129 deletions(-) diff --git a/manage.py b/manage.py index ebb9f5b4f7..3ccb9ecb51 100755 --- a/manage.py +++ b/manage.py @@ -2,75 +2,8 @@ """ CLI to manage redash. """ -import json - - -import click -from flask.cli import FlaskGroup, run_command - -from redash import create_app, settings, __version__ -from redash.cli import users, groups, database, data_sources, organization -from redash.monitor import get_status - - -def create(group): - app = create_app() - group.app = app - return app - - -@click.group(cls=FlaskGroup, create_app=create) -def manager(): - "Management script for redash" - - -manager.add_command(database.manager, "database") -manager.add_command(users.manager, "users") -manager.add_command(groups.manager, "groups") -manager.add_command(data_sources.manager, "ds") -manager.add_command(organization.manager, "org") -manager.add_command(run_command, "runserver") - - -@manager.command() -def version(): - """Displays re:dash version.""" - print __version__ - - -@manager.command() -def status(): - print json.dumps(get_status(), indent=2) - - -@manager.command() -def runworkers(): - """Start workers (deprecated).""" - print "** This command is deprecated. Please use Celery's CLI to control the workers. **" - - -@manager.command() -def check_settings(): - """Show the settings as re:dash sees them (useful for debugging).""" - for name, item in settings.all_settings().iteritems(): - print "{} = {}".format(name, item) - - -@manager.command() -@click.argument('email', default=settings.MAIL_DEFAULT_SENDER, required=False) -def send_test_mail(email=None): - """ - Send test message to EMAIL (default: the address you defined in MAIL_DEFAULT_SENDER) - """ - from redash import mail - from flask_mail import Message - - if email is None: - email = settings.MAIL_DEFAULT_SENDER - - mail.send(Message(subject="Test Message from re:dash", recipients=[email], - body="Test message.")) +from redash.cli import manager if __name__ == '__main__': manager() diff --git a/redash/cli/__init__.py b/redash/cli/__init__.py index e69de29bb2..0546a56e8d 100644 --- a/redash/cli/__init__.py +++ b/redash/cli/__init__.py @@ -0,0 +1,63 @@ +import json + + +import click +from flask.cli import FlaskGroup, run_command + +from redash import create_app, settings, __version__ +from redash.cli import users, groups, database, data_sources, organization +from redash.monitor import get_status + + +def create(group): + app = create_app() + group.app = app + return app + + +@click.group(cls=FlaskGroup, create_app=create) +def manager(): + """Management script for Redash""" + + +manager.add_command(database.manager, "database") +manager.add_command(users.manager, "users") +manager.add_command(groups.manager, "groups") +manager.add_command(data_sources.manager, "ds") +manager.add_command(organization.manager, "org") +manager.add_command(run_command, "runserver") + + +@manager.command() +def version(): + """Displays Redash version.""" + print __version__ + + +@manager.command() +def status(): + print json.dumps(get_status(), indent=2) + + +@manager.command() +def check_settings(): + """Show the settings as Redash sees them (useful for debugging).""" + for name, item in settings.all_settings().iteritems(): + print "{} = {}".format(name, item) + + +@manager.command() +@click.argument('email', default=settings.MAIL_DEFAULT_SENDER, required=False) +def send_test_mail(email=None): + """ + Send test message to EMAIL (default: the address you defined in MAIL_DEFAULT_SENDER) + """ + from redash import mail + from flask_mail import Message + + if email is None: + email = settings.MAIL_DEFAULT_SENDER + + mail.send(Message(subject="Test Message from Redash", recipients=[email], + body="Test message.")) + diff --git a/tests/test_cli.py b/tests/test_cli.py index d292433f9d..809026157e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,14 +6,15 @@ from tests import BaseTestCase from redash.utils.configuration import ConfigurationContainer from redash.query_runner import query_runners -from redash.cli.data_sources import (edit, delete as delete_ds, - list as list_ds, new, test) -from redash.cli.groups import (change_permissions, create as create_group, - list as list_group) -from redash.cli.organization import (list as list_org, set_google_apps_domains, - show_google_apps_domains) -from redash.cli.users import (create as create_user, delete as delete_user, - grant_admin, invite, list as list_user, password) +# from redash.cli.data_sources import (edit, delete as delete_ds, +# list as list_ds, new, test) +# from redash.cli.groups import (change_permissions, create as create_group, +# list as list_group) +# from redash.cli.organization import (list as list_org, set_google_apps_domains, +# show_google_apps_domains) +# from redash.cli.users import (create as create_user, delete as delete_user, +# grant_admin, invite, list as list_user, password) +from redash.cli import manager from redash.models import DataSource, Group, Organization, User, db @@ -22,7 +23,8 @@ def test_interactive_new(self): runner = CliRunner() pg_i = query_runners.keys().index('pg') + 1 result = runner.invoke( - new, + manager, + ['ds', 'new'], input="test\n%s\n\n\nexample.com\n\ntestdb\n" % (pg_i,)) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) @@ -35,9 +37,11 @@ def test_interactive_new(self): def test_options_new(self): runner = CliRunner() result = runner.invoke( - new, ['test', '--options', - '{"host": "example.com", "dbname": "testdb"}', - '--type', 'pg']) + manager, + ['ds', 'new', + 'test', + '--options', '{"host": "example.com", "dbname": "testdb"}', + '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(DataSource.query.count(), 1) @@ -50,7 +54,7 @@ def test_options_new(self): def test_bad_type_new(self): runner = CliRunner() result = runner.invoke( - new, ['test', '--type', 'wrong']) + manager, ['ds', 'new', 'test', '--type', 'wrong']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) @@ -59,9 +63,9 @@ def test_bad_type_new(self): def test_bad_options_new(self): runner = CliRunner() result = runner.invoke( - new, ['test', '--options', - '{"host": 12345, "dbname": "testdb"}', - '--type', 'pg']) + manager, ['ds', 'new', 'test', '--options', + '{"host": 12345, "dbname": "testdb"}', + '--type', 'pg']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) @@ -76,7 +80,7 @@ def test_list(self): name='test2', type='sqlite', options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() - result = runner.invoke(list_ds) + result = runner.invoke(manager, ['ds', 'list']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) expected_output = """ @@ -98,7 +102,7 @@ def test_connection_test(self): name='test1', type='sqlite', options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() - result = runner.invoke(test, ['test1']) + result = runner.invoke(manager, ['ds', 'test', 'test1']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertIn('Success', result.output) @@ -108,7 +112,7 @@ def test_connection_bad_test(self): name='test1', type='sqlite', options=ConfigurationContainer({"dbpath": __file__})) runner = CliRunner() - result = runner.invoke(test, ['test1']) + result = runner.invoke(manager, ['ds', 'test', 'test1']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('Failure', result.output) @@ -118,7 +122,7 @@ def test_connection_delete(self): name='test1', type='sqlite', options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() - result = runner.invoke(delete_ds, ['test1']) + result = runner.invoke(manager, ['ds', 'delete', 'test1']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertIn('Deleting', result.output) @@ -129,7 +133,7 @@ def test_connection_bad_delete(self): name='test1', type='sqlite', options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() - result = runner.invoke(delete_ds, ['wrong']) + result = runner.invoke(manager, ['ds', 'delete', 'wrong']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("Couldn't find", result.output) @@ -141,10 +145,10 @@ def test_options_edit(self): options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() result = runner.invoke( - edit, ['test1', '--options', - '{"host": "example.com", "dbname": "testdb"}', - '--name', 'test2', - '--type', 'pg']) + manager, ['ds', 'edit', 'test1', '--options', + '{"host": "example.com", "dbname": "testdb"}', + '--name', 'test2', + '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(DataSource.query.count(), 1) @@ -160,7 +164,7 @@ def test_bad_type_edit(self): options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() result = runner.invoke( - edit, ['test', '--type', 'wrong']) + manager, ['ds', 'edit', 'test', '--type', 'wrong']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) @@ -173,9 +177,9 @@ def test_bad_options_edit(self): options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) runner = CliRunner() result = runner.invoke( - new, ['test', '--options', - '{"host": 12345, "dbname": "testdb"}', - '--type', 'pg']) + manager, ['ds', 'new', 'test', '--options', + '{"host": 12345, "dbname": "testdb"}', + '--type', 'pg']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) @@ -185,13 +189,12 @@ def test_bad_options_edit(self): class GroupCommandTests(BaseTestCase): - def test_create(self): gcount = Group.query.count() perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() result = runner.invoke( - create_group, ['test', '--permissions', ','.join(perms)]) + manager, ['groups', 'create', 'test', '--permissions', ','.join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(Group.query.count(), gcount + 1) @@ -206,7 +209,7 @@ def test_change_permissions(self): perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() result = runner.invoke( - change_permissions, [str(g_id), '--permissions', ','.join(perms)]) + manager, ['groups', 'change_permissions', str(g_id), '--permissions', ','.join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) g = Group.query.filter(Group.id == g_id).first() @@ -215,7 +218,7 @@ def test_change_permissions(self): def test_list(self): self.factory.create_group(name='test', permissions=['list_dashboards']) runner = CliRunner() - result = runner.invoke(list_group, []) + result = runner.invoke(manager, ['groups', 'list']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -242,10 +245,10 @@ class OrganizationCommandTests(BaseTestCase): def test_set_google_apps_domains(self): domains = ['example.org', 'example.com'] runner = CliRunner() - result = runner.invoke(set_google_apps_domains, [','.join(domains)]) + result = runner.invoke(manager, ['org', 'set_google_apps_domains', ','.join(domains)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - #db.session. + # db.session. db.session.refresh(self.factory.org) self.assertEqual(self.factory.org.google_apps_domains, domains) @@ -255,7 +258,7 @@ def test_show_google_apps_domains(self): db.session.add(self.factory.org) db.session.commit() runner = CliRunner() - result = runner.invoke(show_google_apps_domains, []) + result = runner.invoke(manager, ['org', 'show_google_apps_domains']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -267,7 +270,7 @@ def test_show_google_apps_domains(self): def test_list(self): self.factory.create_org(name='test', slug='test_org') runner = CliRunner() - result = runner.invoke(list_org, []) + result = runner.invoke(manager, ['org', 'list']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -287,7 +290,7 @@ class UserCommandTests(BaseTestCase): def test_create_basic(self): runner = CliRunner() result = runner.invoke( - create_user, ['foobar@example.com', 'Fred Foobar'], + manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar'], input="password1\npassword1\n") self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) @@ -299,20 +302,20 @@ def test_create_basic(self): def test_create_admin(self): runner = CliRunner() result = runner.invoke( - create_user, ['foobar@example.com', 'Fred Foobar', - '--password', 'password1', '--admin']) + manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar', + '--password', 'password1', '--admin']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) self.assertEqual(u.group_ids, [self.factory.default_group.id, - self.factory.admin_group.id]) + self.factory.admin_group.id]) def test_create_googleauth(self): runner = CliRunner() result = runner.invoke( - create_user, ['foobar@example.com', 'Fred Foobar', '--google']) + manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar', '--google']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() @@ -324,7 +327,7 @@ def test_create_bad(self): self.factory.create_user(email='foobar@example.com') runner = CliRunner() result = runner.invoke( - create_user, ['foobar@example.com', 'Fred Foobar'], + manager, ['users', 'create' 'foobar@example.com', 'Fred Foobar'], input="password1\npassword1\n") self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) @@ -334,27 +337,24 @@ def test_delete(self): self.factory.create_user(email='foobar@example.com') ucount = User.query.count() runner = CliRunner() - result = runner.invoke( - delete_user, ['foobar@example.com']) + result = runner.invoke(manager, ['users', 'delete', 'foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(User.query.filter(User.email == - "foobar@example.com").count(), 0) + "foobar@example.com").count(), 0) self.assertEqual(User.query.count(), ucount - 1) def test_delete_bad(self): ucount = User.query.count() runner = CliRunner() - result = runner.invoke( - delete_user, ['foobar@example.com']) + result = runner.invoke(manager, ['users', 'delete', 'foobar@example.com']) self.assertIn('Deleted 0 users', result.output) self.assertEqual(User.query.count(), ucount) def test_password(self): self.factory.create_user(email='foobar@example.com') runner = CliRunner() - result = runner.invoke( - password, ['foobar@example.com', 'xyzzy']) + result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() @@ -362,16 +362,14 @@ def test_password(self): def test_password_bad(self): runner = CliRunner() - result = runner.invoke( - password, ['foobar@example.com', 'xyzzy']) + result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not found', result.output) def test_password_bad_org(self): runner = CliRunner() - result = runner.invoke( - password, ['foobar@example.com', 'xyzzy', '--org', 'default']) + result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy', '--org', 'default']) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not found', result.output) @@ -380,9 +378,7 @@ def test_invite(self): admin = self.factory.create_user(email='redash-admin@example.com') runner = CliRunner() with mock.patch('redash.cli.users.invite_user') as iu: - result = runner.invoke( - invite, ['foobar@example.com', 'Fred Foobar', - 'redash-admin@example.com']) + result = runner.invoke(manager, ['users', 'invite', 'foobar@example.com', 'Fred Foobar', 'redash-admin@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertTrue(iu.called) @@ -391,13 +387,12 @@ def test_invite(self): self.assertEqual(c[1].id, admin.id) self.assertEqual(c[2].email, 'foobar@example.com') - def test_list(self): self.factory.create_user(name='Fred Foobar', email='foobar@example.com', org=self.factory.org) runner = CliRunner() - result = runner.invoke(list_user, []) + result = runner.invoke(manager, ['users', 'list']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -415,8 +410,7 @@ def test_grant_admin(self): org=self.factory.org, group_ids=[self.factory.default_group.id]) runner = CliRunner() - result = runner.invoke( - grant_admin, ['foobar@example.com']) + result = runner.invoke(manager, ['users', 'grant_admin', 'foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) db.session.refresh(u) From 2b33963bee019546de6247fb992ea1197db7a352 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 15:11:42 +0200 Subject: [PATCH 63/80] Add missing db.session.commit calls in CLI --- redash/cli/__init__.py | 3 ++- redash/cli/data_sources.py | 2 ++ redash/cli/users.py | 1 + tests/test_cli.py | 6 ++---- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/redash/cli/__init__.py b/redash/cli/__init__.py index 0546a56e8d..caaf37e895 100644 --- a/redash/cli/__init__.py +++ b/redash/cli/__init__.py @@ -3,6 +3,7 @@ import click from flask.cli import FlaskGroup, run_command +from flask import current_app from redash import create_app, settings, __version__ from redash.cli import users, groups, database, data_sources, organization @@ -10,7 +11,7 @@ def create(group): - app = create_app() + app = current_app or create_app() group.app = app return app diff --git a/redash/cli/data_sources.py b/redash/cli/data_sources.py index b7d77e3752..d8b624888c 100644 --- a/redash/cli/data_sources.py +++ b/redash/cli/data_sources.py @@ -142,6 +142,7 @@ def new(name=None, type=None, options=None, organization='default'): data_source = models.DataSource.create_with_group( name=name, type=type, options=options, org=models.Organization.get_by_slug(organization)) + models.db.session.commit() print "Id: {}".format(data_source.id) @@ -160,6 +161,7 @@ def delete(name, organization='default'): models.DataSource.org == org).one() print "Deleting data source: {} (id={})".format(name, data_source.id) models.db.session.delete(data_source) + models.db.session.commit() except NoResultFound: print "Couldn't find data source named: {}".format(name) exit(1) diff --git a/redash/cli/users.py b/redash/cli/users.py index 11049c73f2..9b4d3734b4 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -115,6 +115,7 @@ def delete(email, organization=None): ).delete() else: deleted_count = models.User.query.filter(models.User.email == email).delete() + db.session.commit() print "Deleted %d users." % deleted_count diff --git a/tests/test_cli.py b/tests/test_cli.py index 809026157e..c7519eb4e8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -193,13 +193,12 @@ def test_create(self): gcount = Group.query.count() perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() - result = runner.invoke( - manager, ['groups', 'create', 'test', '--permissions', ','.join(perms)]) + result = runner.invoke(manager, ['groups', 'create', 'test', '--permissions', ','.join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(Group.query.count(), gcount + 1) g = Group.query.order_by(Group.id.desc()).first() - self.assertEqual(g.org, self.factory.org) + self.assertEqual(g.org_id, self.factory.org.id) self.assertEqual(g.permissions, perms) def test_change_permissions(self): @@ -248,7 +247,6 @@ def test_set_google_apps_domains(self): result = runner.invoke(manager, ['org', 'set_google_apps_domains', ','.join(domains)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - # db.session. db.session.refresh(self.factory.org) self.assertEqual(self.factory.org.google_apps_domains, domains) From 2d206ef47037fa76c318a880a5ae51a890dc432e Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 15:21:13 +0200 Subject: [PATCH 64/80] Switch to flask.cli.AppGroup instead of flask.cli.with_app_context --- redash/cli/data_sources.py | 9 ++------- redash/cli/database.py | 7 ++----- redash/cli/groups.py | 9 +++------ redash/cli/organization.py | 9 +++------ redash/cli/users.py | 14 ++++---------- 5 files changed, 14 insertions(+), 34 deletions(-) diff --git a/redash/cli/data_sources.py b/redash/cli/data_sources.py index d8b624888c..ed1dd3de70 100644 --- a/redash/cli/data_sources.py +++ b/redash/cli/data_sources.py @@ -2,7 +2,7 @@ import json import click -from flask.cli import with_appcontext +from flask.cli import AppGroup from sqlalchemy.orm.exc import NoResultFound from redash import models @@ -10,11 +10,10 @@ from redash.query_runner import get_configuration_schema_for_query_runner_type from redash.utils.configuration import ConfigurationContainer -manager = click.Group(help="Data sources management commands.") +manager = AppGroup(help="Data sources management commands.") @manager.command() -@with_appcontext @click.option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for " "all organizations).") @@ -43,7 +42,6 @@ def validate_data_source_type(type): @manager.command() -@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to " @@ -70,7 +68,6 @@ def test(name, organization='default'): @manager.command() -@with_appcontext @click.argument('name', default=None, required=False) @click.option('--type', default=None, help="new type for the data source") @@ -147,7 +144,6 @@ def new(name=None, type=None, options=None, organization='default'): @manager.command() -@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -175,7 +171,6 @@ def update_attr(obj, attr, new_value): @manager.command() -@with_appcontext @click.argument('name') @click.option('--name', 'new_name', default=None, help="new name for the data source") diff --git a/redash/cli/database.py b/redash/cli/database.py index 57fb234146..e291d81ce0 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -1,11 +1,9 @@ -from click import Group -from flask.cli import with_appcontext +from flask.cli import AppGroup -manager = Group(help="Manage the database (create/drop tables).") +manager = AppGroup(help="Manage the database (create/drop tables).") @manager.command() -@with_appcontext def create_tables(): """Create the database tables.""" from redash.models import db, create_db, init_db @@ -15,7 +13,6 @@ def create_tables(): @manager.command() -@with_appcontext def drop_tables(): """Drop the database tables.""" from redash.models import create_db diff --git a/redash/cli/groups.py b/redash/cli/groups.py index 92b264f165..6bd8779565 100644 --- a/redash/cli/groups.py +++ b/redash/cli/groups.py @@ -1,16 +1,15 @@ from sys import exit from sqlalchemy.orm.exc import NoResultFound -from flask.cli import with_appcontext -from click import Group, argument, option +from flask.cli import AppGroup +from click import argument, option from redash import models -manager = Group(help="Groups management commands.") +manager = AppGroup(help="Groups management commands.") @manager.command() -@with_appcontext @argument('name') @option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -41,7 +40,6 @@ def create(name, permissions=None, organization='default'): @manager.command() -@with_appcontext @argument('group_id') @option('--permissions', default=None, help="Comma separated list of permissions ('create_dashboard'," @@ -82,7 +80,6 @@ def extract_permissions_string(permissions): @manager.command() -@with_appcontext @option('--org', 'organization', default=None, help="The organization to limit to (leave blank for all).") def list(organization=None): diff --git a/redash/cli/organization.py b/redash/cli/organization.py index 73daef0658..70d4f1358c 100644 --- a/redash/cli/organization.py +++ b/redash/cli/organization.py @@ -1,13 +1,12 @@ -from click import Group, argument -from flask.cli import with_appcontext +from click import argument +from flask.cli import AppGroup from redash import models -manager = Group(help="Organization management commands.") +manager = AppGroup(help="Organization management commands.") @manager.command() -@with_appcontext @argument('domains') def set_google_apps_domains(domains): """ @@ -23,7 +22,6 @@ def set_google_apps_domains(domains): @manager.command() -@with_appcontext def show_google_apps_domains(): organization = models.Organization.query.first() print "Current list of Google Apps domains: {}".format( @@ -31,7 +29,6 @@ def show_google_apps_domains(): @manager.command() -@with_appcontext def list(): """List all organizations""" orgs = models.Organization.query diff --git a/redash/cli/users.py b/redash/cli/users.py index 9b4d3734b4..f29072dbc2 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -1,14 +1,14 @@ from sys import exit -from click import BOOL, Group, argument, option, prompt -from flask.cli import with_appcontext +from click import BOOL, argument, option, prompt +from flask.cli import AppGroup from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.exc import IntegrityError from redash import models from redash.handlers.users import invite_user -manager = Group(help="Users management commands.") +manager = AppGroup(help="Users management commands.") def build_groups(org, groups, is_admin): @@ -27,7 +27,6 @@ def build_groups(org, groups, is_admin): @manager.command() -@with_appcontext @argument('email') @option('--org', 'organization', default='default', help="the organization the user belongs to, (leave blank for " @@ -53,7 +52,6 @@ def grant_admin(email, organization='default'): @manager.command() -@with_appcontext @argument('email') @argument('name') @option('--org', 'organization', default='default', @@ -98,7 +96,6 @@ def create(email, name, groups, is_admin=False, google_auth=False, @manager.command() -@with_appcontext @argument('email') @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" @@ -115,12 +112,11 @@ def delete(email, organization=None): ).delete() else: deleted_count = models.User.query.filter(models.User.email == email).delete() - db.session.commit() + models.db.session.commit() print "Deleted %d users." % deleted_count @manager.command() -@with_appcontext @argument('email') @argument('password') @option('--org', 'organization', default=None, @@ -150,7 +146,6 @@ def password(email, password, organization=None): @manager.command() -@with_appcontext @argument('email') @argument('name') @argument('inviter_email') @@ -185,7 +180,6 @@ def invite(email, name, inviter_email, groups, is_admin=False, @manager.command() -@with_appcontext @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" " organizations)") From abf57e4e70e5adf31f11b0575f09a3462f4fcafb Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 15:33:58 +0200 Subject: [PATCH 65/80] Upgrade setuptools to install mock --- circle.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/circle.yml b/circle.yml index ff250851ef..a306813209 100644 --- a/circle.yml +++ b/circle.yml @@ -9,6 +9,7 @@ machine: 2.7.3 dependencies: pre: + - pip install --upgrade setuptools - pip install -r requirements_dev.txt - pip install -r requirements.txt - pip install pymongo==3.2.1 From 74e6ef5c1d1ac4dd6c5939ce5aca7eeb3b8fc0ea Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 15:54:08 +0200 Subject: [PATCH 66/80] Fix users CLI tests --- tests/test_cli.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index c7519eb4e8..2dd5db1aaf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,19 +1,10 @@ +import mock import textwrap - from click.testing import CliRunner -import mock from tests import BaseTestCase from redash.utils.configuration import ConfigurationContainer from redash.query_runner import query_runners -# from redash.cli.data_sources import (edit, delete as delete_ds, -# list as list_ds, new, test) -# from redash.cli.groups import (change_permissions, create as create_group, -# list as list_group) -# from redash.cli.organization import (list as list_org, set_google_apps_domains, -# show_google_apps_domains) -# from redash.cli.users import (create as create_user, delete as delete_user, -# grant_admin, invite, list as list_user, password) from redash.cli import manager from redash.models import DataSource, Group, Organization, User, db @@ -295,7 +286,7 @@ def test_create_basic(self): u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.group_ids, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [u.org.default_group.id]) def test_create_admin(self): runner = CliRunner() @@ -307,8 +298,8 @@ def test_create_admin(self): u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.group_ids, [self.factory.default_group.id, - self.factory.admin_group.id]) + self.assertEqual(u.group_ids, [u.org.default_group.id, + u.org.admin_group.id]) def test_create_googleauth(self): runner = CliRunner() @@ -319,13 +310,13 @@ def test_create_googleauth(self): u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertIsNone(u.password_hash) - self.assertEqual(u.group_ids, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [u.org.default_group.id]) def test_create_bad(self): self.factory.create_user(email='foobar@example.com') runner = CliRunner() result = runner.invoke( - manager, ['users', 'create' 'foobar@example.com', 'Fred Foobar'], + manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar'], input="password1\npassword1\n") self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) @@ -411,6 +402,5 @@ def test_grant_admin(self): result = runner.invoke(manager, ['users', 'grant_admin', 'foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - db.session.refresh(u) - self.assertEqual(u.group_ids, [self.factory.default_group.id, - self.factory.admin_group.id]) + self.assertEqual(u.group_ids, [u.org.default_group.id, + u.org.admin_group.id]) From 70d545410dc68da0f98cf395795634265004ada3 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 17:59:01 +0200 Subject: [PATCH 67/80] Add Flask-Migrate to the project Also moved old migrations to old_migrations folder (before deleting them entirely). --- migrations/README | 1 + migrations/alembic.ini | 45 ++++++++++ migrations/env.py | 87 +++++++++++++++++++ migrations/script.py.mako | 24 +++++ .../0001_allow_delete_query.py | 0 .../0002_fix_timestamp_fields.py | 0 .../0003_update_data_source_config.py | 0 .../0004_allow_null_in_event_user.py | 0 .../0005_add_updated_at.py | 0 .../0006_queries_last_edit_by.py | 0 .../0007_add_schedule_to_queries.py | 0 .../0008_make_ds_name_unique.py | 0 .../0009_add_api_key_to_user.py | 0 .../0010_allow_deleting_datasources.py | 0 .../0010_create_alerts.py | 0 .../0011_migrate_bigquery_to_json.py | 0 .../0012_add_list_users_permission.py | 0 .../0013_update_counter_options.py | 0 .../0014_add_alert_rearm_seconds.py | 0 .../0014_migrate_existing_es_to_kibana.py | 0 .../0015_add_schedule_query_permission.py | 0 .../0016_add_alert_subscriber.py | 0 .../0016_drop_tables_from_group.py | 0 .../0017_add_organization.py | 0 .../0018_add_groups_refs.py | 0 .../0019_add_super_admin_permission.py | 0 .../0020_change_ds_name_to_non_uniqe.py | 0 .../0021_create_api_keys_table.py | 0 .../0022_add_new_permissions.py | 0 .../0023_add_notification_destination.py | 0 .../0023_make_sure_correct_group_type.py | 0 .../0024_add_options_to_query.py | 0 .../0025_add_query_snippets_table.py | 0 .../0026_add_access_control_tables.py | 0 .../0026_remove_query_trackers_redis_key.py | 0 .../0027_add_draft_toggle.py | 0 redash/__init__.py | 3 + 37 files changed, 160 insertions(+) create mode 100755 migrations/README create mode 100644 migrations/alembic.ini create mode 100755 migrations/env.py create mode 100755 migrations/script.py.mako rename {migrations => old_migrations}/0001_allow_delete_query.py (100%) rename {migrations => old_migrations}/0002_fix_timestamp_fields.py (100%) rename {migrations => old_migrations}/0003_update_data_source_config.py (100%) rename {migrations => old_migrations}/0004_allow_null_in_event_user.py (100%) rename {migrations => old_migrations}/0005_add_updated_at.py (100%) rename {migrations => old_migrations}/0006_queries_last_edit_by.py (100%) rename {migrations => old_migrations}/0007_add_schedule_to_queries.py (100%) rename {migrations => old_migrations}/0008_make_ds_name_unique.py (100%) rename {migrations => old_migrations}/0009_add_api_key_to_user.py (100%) rename {migrations => old_migrations}/0010_allow_deleting_datasources.py (100%) rename {migrations => old_migrations}/0010_create_alerts.py (100%) rename {migrations => old_migrations}/0011_migrate_bigquery_to_json.py (100%) rename {migrations => old_migrations}/0012_add_list_users_permission.py (100%) rename {migrations => old_migrations}/0013_update_counter_options.py (100%) rename {migrations => old_migrations}/0014_add_alert_rearm_seconds.py (100%) rename {migrations => old_migrations}/0014_migrate_existing_es_to_kibana.py (100%) rename {migrations => old_migrations}/0015_add_schedule_query_permission.py (100%) rename {migrations => old_migrations}/0016_add_alert_subscriber.py (100%) rename {migrations => old_migrations}/0016_drop_tables_from_group.py (100%) rename {migrations => old_migrations}/0017_add_organization.py (100%) rename {migrations => old_migrations}/0018_add_groups_refs.py (100%) rename {migrations => old_migrations}/0019_add_super_admin_permission.py (100%) rename {migrations => old_migrations}/0020_change_ds_name_to_non_uniqe.py (100%) rename {migrations => old_migrations}/0021_create_api_keys_table.py (100%) rename {migrations => old_migrations}/0022_add_new_permissions.py (100%) rename {migrations => old_migrations}/0023_add_notification_destination.py (100%) rename {migrations => old_migrations}/0023_make_sure_correct_group_type.py (100%) rename {migrations => old_migrations}/0024_add_options_to_query.py (100%) rename {migrations => old_migrations}/0025_add_query_snippets_table.py (100%) rename {migrations => old_migrations}/0026_add_access_control_tables.py (100%) rename {migrations => old_migrations}/0026_remove_query_trackers_redis_key.py (100%) rename {migrations => old_migrations}/0027_add_draft_toggle.py (100%) diff --git a/migrations/README b/migrations/README new file mode 100755 index 0000000000..98e4f9c44e --- /dev/null +++ b/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/migrations/alembic.ini b/migrations/alembic.ini new file mode 100644 index 0000000000..f8ed4801f7 --- /dev/null +++ b/migrations/alembic.ini @@ -0,0 +1,45 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100755 index 0000000000..4593816063 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,87 @@ +from __future__ import with_statement +from alembic import context +from sqlalchemy import engine_from_config, pool +from logging.config import fileConfig +import logging + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +from flask import current_app +config.set_main_option('sqlalchemy.url', + current_app.config.get('SQLALCHEMY_DATABASE_URI')) +target_metadata = current_app.extensions['migrate'].db.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure(url=url) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.readthedocs.org/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + engine = engine_from_config(config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) + + connection = engine.connect() + context.configure(connection=connection, + target_metadata=target_metadata, + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args) + + try: + with context.begin_transaction(): + context.run_migrations() + finally: + connection.close() + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100755 index 0000000000..2c0156303a --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/migrations/0001_allow_delete_query.py b/old_migrations/0001_allow_delete_query.py similarity index 100% rename from migrations/0001_allow_delete_query.py rename to old_migrations/0001_allow_delete_query.py diff --git a/migrations/0002_fix_timestamp_fields.py b/old_migrations/0002_fix_timestamp_fields.py similarity index 100% rename from migrations/0002_fix_timestamp_fields.py rename to old_migrations/0002_fix_timestamp_fields.py diff --git a/migrations/0003_update_data_source_config.py b/old_migrations/0003_update_data_source_config.py similarity index 100% rename from migrations/0003_update_data_source_config.py rename to old_migrations/0003_update_data_source_config.py diff --git a/migrations/0004_allow_null_in_event_user.py b/old_migrations/0004_allow_null_in_event_user.py similarity index 100% rename from migrations/0004_allow_null_in_event_user.py rename to old_migrations/0004_allow_null_in_event_user.py diff --git a/migrations/0005_add_updated_at.py b/old_migrations/0005_add_updated_at.py similarity index 100% rename from migrations/0005_add_updated_at.py rename to old_migrations/0005_add_updated_at.py diff --git a/migrations/0006_queries_last_edit_by.py b/old_migrations/0006_queries_last_edit_by.py similarity index 100% rename from migrations/0006_queries_last_edit_by.py rename to old_migrations/0006_queries_last_edit_by.py diff --git a/migrations/0007_add_schedule_to_queries.py b/old_migrations/0007_add_schedule_to_queries.py similarity index 100% rename from migrations/0007_add_schedule_to_queries.py rename to old_migrations/0007_add_schedule_to_queries.py diff --git a/migrations/0008_make_ds_name_unique.py b/old_migrations/0008_make_ds_name_unique.py similarity index 100% rename from migrations/0008_make_ds_name_unique.py rename to old_migrations/0008_make_ds_name_unique.py diff --git a/migrations/0009_add_api_key_to_user.py b/old_migrations/0009_add_api_key_to_user.py similarity index 100% rename from migrations/0009_add_api_key_to_user.py rename to old_migrations/0009_add_api_key_to_user.py diff --git a/migrations/0010_allow_deleting_datasources.py b/old_migrations/0010_allow_deleting_datasources.py similarity index 100% rename from migrations/0010_allow_deleting_datasources.py rename to old_migrations/0010_allow_deleting_datasources.py diff --git a/migrations/0010_create_alerts.py b/old_migrations/0010_create_alerts.py similarity index 100% rename from migrations/0010_create_alerts.py rename to old_migrations/0010_create_alerts.py diff --git a/migrations/0011_migrate_bigquery_to_json.py b/old_migrations/0011_migrate_bigquery_to_json.py similarity index 100% rename from migrations/0011_migrate_bigquery_to_json.py rename to old_migrations/0011_migrate_bigquery_to_json.py diff --git a/migrations/0012_add_list_users_permission.py b/old_migrations/0012_add_list_users_permission.py similarity index 100% rename from migrations/0012_add_list_users_permission.py rename to old_migrations/0012_add_list_users_permission.py diff --git a/migrations/0013_update_counter_options.py b/old_migrations/0013_update_counter_options.py similarity index 100% rename from migrations/0013_update_counter_options.py rename to old_migrations/0013_update_counter_options.py diff --git a/migrations/0014_add_alert_rearm_seconds.py b/old_migrations/0014_add_alert_rearm_seconds.py similarity index 100% rename from migrations/0014_add_alert_rearm_seconds.py rename to old_migrations/0014_add_alert_rearm_seconds.py diff --git a/migrations/0014_migrate_existing_es_to_kibana.py b/old_migrations/0014_migrate_existing_es_to_kibana.py similarity index 100% rename from migrations/0014_migrate_existing_es_to_kibana.py rename to old_migrations/0014_migrate_existing_es_to_kibana.py diff --git a/migrations/0015_add_schedule_query_permission.py b/old_migrations/0015_add_schedule_query_permission.py similarity index 100% rename from migrations/0015_add_schedule_query_permission.py rename to old_migrations/0015_add_schedule_query_permission.py diff --git a/migrations/0016_add_alert_subscriber.py b/old_migrations/0016_add_alert_subscriber.py similarity index 100% rename from migrations/0016_add_alert_subscriber.py rename to old_migrations/0016_add_alert_subscriber.py diff --git a/migrations/0016_drop_tables_from_group.py b/old_migrations/0016_drop_tables_from_group.py similarity index 100% rename from migrations/0016_drop_tables_from_group.py rename to old_migrations/0016_drop_tables_from_group.py diff --git a/migrations/0017_add_organization.py b/old_migrations/0017_add_organization.py similarity index 100% rename from migrations/0017_add_organization.py rename to old_migrations/0017_add_organization.py diff --git a/migrations/0018_add_groups_refs.py b/old_migrations/0018_add_groups_refs.py similarity index 100% rename from migrations/0018_add_groups_refs.py rename to old_migrations/0018_add_groups_refs.py diff --git a/migrations/0019_add_super_admin_permission.py b/old_migrations/0019_add_super_admin_permission.py similarity index 100% rename from migrations/0019_add_super_admin_permission.py rename to old_migrations/0019_add_super_admin_permission.py diff --git a/migrations/0020_change_ds_name_to_non_uniqe.py b/old_migrations/0020_change_ds_name_to_non_uniqe.py similarity index 100% rename from migrations/0020_change_ds_name_to_non_uniqe.py rename to old_migrations/0020_change_ds_name_to_non_uniqe.py diff --git a/migrations/0021_create_api_keys_table.py b/old_migrations/0021_create_api_keys_table.py similarity index 100% rename from migrations/0021_create_api_keys_table.py rename to old_migrations/0021_create_api_keys_table.py diff --git a/migrations/0022_add_new_permissions.py b/old_migrations/0022_add_new_permissions.py similarity index 100% rename from migrations/0022_add_new_permissions.py rename to old_migrations/0022_add_new_permissions.py diff --git a/migrations/0023_add_notification_destination.py b/old_migrations/0023_add_notification_destination.py similarity index 100% rename from migrations/0023_add_notification_destination.py rename to old_migrations/0023_add_notification_destination.py diff --git a/migrations/0023_make_sure_correct_group_type.py b/old_migrations/0023_make_sure_correct_group_type.py similarity index 100% rename from migrations/0023_make_sure_correct_group_type.py rename to old_migrations/0023_make_sure_correct_group_type.py diff --git a/migrations/0024_add_options_to_query.py b/old_migrations/0024_add_options_to_query.py similarity index 100% rename from migrations/0024_add_options_to_query.py rename to old_migrations/0024_add_options_to_query.py diff --git a/migrations/0025_add_query_snippets_table.py b/old_migrations/0025_add_query_snippets_table.py similarity index 100% rename from migrations/0025_add_query_snippets_table.py rename to old_migrations/0025_add_query_snippets_table.py diff --git a/migrations/0026_add_access_control_tables.py b/old_migrations/0026_add_access_control_tables.py similarity index 100% rename from migrations/0026_add_access_control_tables.py rename to old_migrations/0026_add_access_control_tables.py diff --git a/migrations/0026_remove_query_trackers_redis_key.py b/old_migrations/0026_remove_query_trackers_redis_key.py similarity index 100% rename from migrations/0026_remove_query_trackers_redis_key.py rename to old_migrations/0026_remove_query_trackers_redis_key.py diff --git a/migrations/0027_add_draft_toggle.py b/old_migrations/0027_add_draft_toggle.py similarity index 100% rename from migrations/0027_add_draft_toggle.py rename to old_migrations/0027_add_draft_toggle.py diff --git a/redash/__init__.py b/redash/__init__.py index 593fc0f2bb..60af6b265b 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -9,6 +9,7 @@ from flask_mail import Mail from flask_limiter import Limiter from flask_limiter.util import get_ipaddr +from flask_migrate import Migrate from redash import settings from redash.query_runner import import_query_runners @@ -52,6 +53,7 @@ def create_redis_connection(): setup_logging() redis_connection = create_redis_connection() mail = Mail() +migrate = Migrate() mail.init_mail(settings.all_settings()) statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX) limiter = Limiter(key_func=get_ipaddr, storage_uri=settings.REDIS_URL) @@ -110,6 +112,7 @@ def create_app(): provision_app(app) db.init_app(app) + migrate.init_app(app, db) init_admin(app) mail.init_app(app) setup_authentication(app) From 923c463e4a8e2da3b2cbe60d7351a3036257551d Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Wed, 7 Dec 2016 18:32:49 +0200 Subject: [PATCH 68/80] Add migration for the is_draft column --- ...746_add_is_draft_status_to_queries_and_.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py diff --git a/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py new file mode 100644 index 0000000000..c741420085 --- /dev/null +++ b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py @@ -0,0 +1,37 @@ +"""Add is_draft status to queries and dashboards + +Revision ID: 65fc9ede4746 +Revises: +Create Date: 2016-12-07 18:08:13.395586 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +from sqlalchemy.exc import ProgrammingError + +revision = '65fc9ede4746' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + op.get_bind() + try: + op.add_column('queries', sa.Column('is_draft', sa.Boolean, default=True, index=True)) + op.add_column('dashboards', sa.Column('is_draft', sa.Boolean, default=True, index=True)) + op.execute("UPDATE queries SET is_draft = (name = 'New Query')") + op.execute("UPDATE dashboards SET is_draft = false") + except ProgrammingError as e: + # The columns might exist if you ran the old migrations. + if 'column "is_draft" of relation "queries" already exists' in e.message: + print "*** Skipping creationg of is_draft columns as they already exist." + op.execute("ROLLBACK") + + +def downgrade(): + op.drop_column('queries', 'is_draft') + op.drop_column('dashboards', 'is_draft') From da31d983b751c4c6f21b24284b96ad0e45ab7f1b Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 13:18:07 -0600 Subject: [PATCH 69/80] Fix test_cli tests --- tests/test_cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2dd5db1aaf..97b65476e7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -189,6 +189,7 @@ def test_create(self): self.assertEqual(result.exit_code, 0) self.assertEqual(Group.query.count(), gcount + 1) g = Group.query.order_by(Group.id.desc()).first() + db.session.add(self.factory.org) self.assertEqual(g.org_id, self.factory.org.id) self.assertEqual(g.permissions, perms) @@ -238,7 +239,7 @@ def test_set_google_apps_domains(self): result = runner.invoke(manager, ['org', 'set_google_apps_domains', ','.join(domains)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - db.session.refresh(self.factory.org) + db.session.add(self.factory.org) self.assertEqual(self.factory.org.google_apps_domains, domains) def test_show_google_apps_domains(self): @@ -372,6 +373,7 @@ def test_invite(self): self.assertEqual(result.exit_code, 0) self.assertTrue(iu.called) c = iu.call_args[0] + db.session.add_all(c) self.assertEqual(c[0].id, self.factory.org.id) self.assertEqual(c[1].id, admin.id) self.assertEqual(c[2].email, 'foobar@example.com') @@ -402,5 +404,6 @@ def test_grant_admin(self): result = runner.invoke(manager, ['users', 'grant_admin', 'foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) + db.session.add(u) self.assertEqual(u.group_ids, [u.org.default_group.id, u.org.admin_group.id]) From 4ba399af673a1c75ca1cdd7d3ff73563e1db66c8 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 17:57:21 -0600 Subject: [PATCH 70/80] fix basic query execution from UI --- redash/tasks/queries.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/redash/tasks/queries.py b/redash/tasks/queries.py index 1174e480e5..d7fd5d913f 100644 --- a/redash/tasks/queries.py +++ b/redash/tasks/queries.py @@ -332,8 +332,8 @@ def cleanup_query_results(): unused_query_results = models.QueryResult.unused(settings.QUERY_RESULTS_CLEANUP_MAX_AGE).limit(settings.QUERY_RESULTS_CLEANUP_COUNT) total_unused_query_results = models.QueryResult.unused().count() - deleted_count = models.QueryResult.delete().where(models.QueryResult.id << unused_query_results).execute() - + deleted_count = unused_query_results.delete() + models.db.session.commit() logger.info("Deleted %d unused query results out of total of %d." % (deleted_count, total_unused_query_results)) @@ -385,7 +385,7 @@ def __init__(self, task, query, data_source_id, user_id, metadata): self.metadata = metadata self.data_source = self._load_data_source() if user_id is not None: - self.user = models.User.get_by_id(user_id) + self.user = models.User.query.get(user_id) else: self.user = None self.query_hash = gen_query_hash(self.query) @@ -424,16 +424,17 @@ def run(self): self.tracker.update(state='failed') result = QueryExecutionError(error) else: - query_result, updated_query_ids = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, - self.query_hash, self.query, data, - run_time, utils.utcnow()) + query_result, updated_query_ids = models.QueryResult.store_result( + self.data_source.org, self.data_source, + self.query_hash, self.query, data, + run_time, utils.utcnow()) self._log_progress('checking_alerts') for query_id in updated_query_ids: check_alerts_for_query.delay(query_id) self._log_progress('finished') result = query_result.id - + models.db.session.commit() return result def _annotate_query(self, query_runner): @@ -457,7 +458,7 @@ def _log_progress(self, state): def _load_data_source(self): logger.info("task=execute_query state=load_ds ds_id=%d", self.data_source_id) - return models.DataSource.get_by_id(self.data_source_id) + return models.DataSource.query.get(self.data_source_id) # user_id is added last as a keyword argument for backward compatability -- to support executing previously submitted From 4945d0bec76a63e9db46b9cfe919d40480c71381 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 7 Dec 2016 19:59:48 -0600 Subject: [PATCH 71/80] fix cleanup_query_results task --- redash/models.py | 2 +- redash/tasks/queries.py | 4 +++- tests/test_models.py | 11 ++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/redash/models.py b/redash/models.py index 00b641f2e0..94bcf3bfda 100644 --- a/redash/models.py +++ b/redash/models.py @@ -557,7 +557,7 @@ def to_dict(self): def unused(cls, days=7): age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) - unused_results = (db.session.query(QueryResult).filter( + unused_results = (db.session.query(QueryResult.id).filter( Query.id == None, QueryResult.retrieved_at < age_threshold) .outerjoin(Query)) diff --git a/redash/tasks/queries.py b/redash/tasks/queries.py index d7fd5d913f..5f30e2fae1 100644 --- a/redash/tasks/queries.py +++ b/redash/tasks/queries.py @@ -332,7 +332,9 @@ def cleanup_query_results(): unused_query_results = models.QueryResult.unused(settings.QUERY_RESULTS_CLEANUP_MAX_AGE).limit(settings.QUERY_RESULTS_CLEANUP_COUNT) total_unused_query_results = models.QueryResult.unused().count() - deleted_count = unused_query_results.delete() + deleted_count = models.Query.query.filter( + models.Query.id.in_(unused_query_results.subquery()) + ).delete(synchronize_session=False) models.db.session.commit() logger.info("Deleted %d unused query results out of total of %d." % (deleted_count, total_unused_query_results)) diff --git a/tests/test_models.py b/tests/test_models.py index 12940db090..0e6fdc5c16 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -393,18 +393,19 @@ def test_returns_only_unused_query_results(self): two_weeks_ago = utcnow() - datetime.timedelta(days=14) qr = self.factory.create_query_result() query = self.factory.create_query(latest_query_data=qr) + db.session.flush() unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) - - self.assertIn(unused_qr, models.QueryResult.unused()) - self.assertNotIn(qr, models.QueryResult.unused()) + self.assertIn((unused_qr.id,), models.QueryResult.unused()) + self.assertNotIn((qr.id,), list(models.QueryResult.unused())) def test_returns_only_over_a_week_old_results(self): two_weeks_ago = utcnow() - datetime.timedelta(days=14) unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) + db.session.flush() new_unused_qr = self.factory.create_query_result() - self.assertIn(unused_qr, models.QueryResult.unused()) - self.assertNotIn(new_unused_qr, models.QueryResult.unused()) + self.assertIn((unused_qr.id,), models.QueryResult.unused()) + self.assertNotIn((new_unused_qr.id,), models.QueryResult.unused()) class TestQueryAll(BaseTestCase): From 3cce4d0ce484e1790feea3c0d5d73a849bca87f5 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 14:59:21 +0200 Subject: [PATCH 72/80] Add warning script to migrations folder --- migrations/0001_warning.py | 9 +++++++++ redash/models.py | 7 ------- 2 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 migrations/0001_warning.py diff --git a/migrations/0001_warning.py b/migrations/0001_warning.py new file mode 100644 index 0000000000..f298121e47 --- /dev/null +++ b/migrations/0001_warning.py @@ -0,0 +1,9 @@ +# This is here just to print a warning for users who use the old Fabric upgrade script. + +if __name__ == '__main__': + warning = "You're using an outdated upgrade script that is running migrations the wrong way. Please upgrade to " \ + "newer version of the script before continuning the upgrade process." + print "*" * 20 + print warning + print "*" * 20 + exit(1) diff --git a/redash/models.py b/redash/models.py index 94bcf3bfda..a3d628e156 100644 --- a/redash/models.py +++ b/redash/models.py @@ -63,14 +63,7 @@ def object(self, value): self.object_id = value.id -# # Support for cast operation on database fields -# @peewee.Node.extend() -# def cast(self, as_type): -# return peewee.Expression(self, '::', peewee.SQL(as_type)) - - # XXX replace PseudoJSON and MutableDict with real JSON field - class PseudoJSON(TypeDecorator): impl = db.Text From 12cbfe1af4dd379b723ef2d6552241a2af3c6a26 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 15:35:44 +0200 Subject: [PATCH 73/80] Stamp database on first creation --- ...ede4746_add_is_draft_status_to_queries_and_.py | 7 +++---- redash/cli/database.py | 15 +++++++++------ redash/models.py | 9 --------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py index c741420085..91398a1574 100644 --- a/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py +++ b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py @@ -7,7 +7,6 @@ """ from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. from sqlalchemy.exc import ProgrammingError @@ -19,7 +18,6 @@ def upgrade(): - op.get_bind() try: op.add_column('queries', sa.Column('is_draft', sa.Boolean, default=True, index=True)) op.add_column('dashboards', sa.Column('is_draft', sa.Boolean, default=True, index=True)) @@ -28,8 +26,9 @@ def upgrade(): except ProgrammingError as e: # The columns might exist if you ran the old migrations. if 'column "is_draft" of relation "queries" already exists' in e.message: - print "*** Skipping creationg of is_draft columns as they already exist." - op.execute("ROLLBACK") + print "Can't run this migration as you already have is_draft columns, please run:" + print "./manage.py db stamp {} # you might need to alter the command to match your environment.".format(revision) + exit() def downgrade(): diff --git a/redash/cli/database.py b/redash/cli/database.py index e291d81ce0..09d7b35800 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -1,20 +1,23 @@ from flask.cli import AppGroup - +from flask_migrate import stamp manager = AppGroup(help="Manage the database (create/drop tables).") @manager.command() def create_tables(): """Create the database tables.""" - from redash.models import db, create_db, init_db - create_db(True, True) + from redash.models import db, init_db + db.create_all() + + # Need to mark current DB as up to date + stamp() + init_db() - db.session.commit() @manager.command() def drop_tables(): """Drop the database tables.""" - from redash.models import create_db + from redash.models import db - create_db(False, True) + db.drop_all() diff --git a/redash/models.py b/redash/models.py index a3d628e156..9581e4229f 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1473,12 +1473,3 @@ def init_db(): db.session.commit() return default_org, admin_group, default_group - -def create_db(create_tables, drop_tables): - # TODO: use these methods directly - if drop_tables: - db.session.rollback() - db.drop_all() - - if create_tables: - db.create_all() From 106c7436476e736438aa2e9fa0fda15cc57d1b2f Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 15:36:08 +0200 Subject: [PATCH 74/80] Keep same logging format in ALembic --- migrations/alembic.ini | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/migrations/alembic.ini b/migrations/alembic.ini index f8ed4801f7..138c144473 100644 --- a/migrations/alembic.ini +++ b/migrations/alembic.ini @@ -41,5 +41,4 @@ level = NOTSET formatter = generic [formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S +format = [%(asctime)s][PID:%(process)d][%(levelname)s][%(name)s] %(message)s From c3805969305785fd83cd7a654c107666c74a69a9 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 16:02:51 +0200 Subject: [PATCH 75/80] Fix cases where we used User.groups instead of User.group_ids --- redash/handlers/alerts.py | 2 +- redash/handlers/dashboards.py | 7 ++++--- redash/handlers/groups.py | 14 ++++++-------- redash/handlers/queries.py | 9 ++++----- redash/models.py | 6 +++--- tests/handlers/test_alerts.py | 23 +++++++++++++++++++++++ tests/handlers/test_groups.py | 13 +++++++++++++ tests/models/test_queries.py | 33 ++++++++++++++++++++++++++++++++- 8 files changed, 86 insertions(+), 21 deletions(-) diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index 82abe07b88..34c7826c08 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -70,7 +70,7 @@ def post(self): @require_permission('list_alerts') def get(self): - return [alert.to_dict() for alert in models.Alert.all(groups=self.current_user.groups)] + return [alert.to_dict() for alert in models.Alert.all(group_ids=self.current_user.group_ids)] class AlertSubscriptionListResource(BaseResource): diff --git a/redash/handlers/dashboards.py b/redash/handlers/dashboards.py index ef2a35847a..9ed0c9124f 100644 --- a/redash/handlers/dashboards.py +++ b/redash/handlers/dashboards.py @@ -15,11 +15,11 @@ class RecentDashboardsResource(BaseResource): @require_permission('list_dashboards') def get(self): - recent = [d.to_dict() for d in models.Dashboard.recent(self.current_org, self.current_user.groups, self.current_user.id, for_user=True)] + recent = [d.to_dict() for d in models.Dashboard.recent(self.current_org, self.current_user.group_ids, self.current_user.id, for_user=True)] global_recent = [] if len(recent) < 10: - global_recent = [d.to_dict() for d in models.Dashboard.recent(self.current_org, self.current_user.groups, self.current_user.id)] + global_recent = [d.to_dict() for d in models.Dashboard.recent(self.current_org, self.current_user.group_ids, self.current_user.id)] return take(20, distinct(chain(recent, global_recent), key=lambda d: d['id'])) @@ -27,7 +27,7 @@ def get(self): class DashboardListResource(BaseResource): @require_permission('list_dashboards') def get(self): - results = models.Dashboard.all(self.current_org, self.current_user.groups, self.current_user) + results = models.Dashboard.all(self.current_org, self.current_user.group_ids, self.current_user) return [q.to_dict() for q in results] @require_permission('create_dashboard') @@ -42,6 +42,7 @@ def post(self): models.db.session.commit() return dashboard.to_dict() + class DashboardResource(BaseResource): @require_permission('list_dashboards') def get(self, dashboard_slug=None): diff --git a/redash/handlers/groups.py b/redash/handlers/groups.py index c43a5518f9..a4f21fe194 100644 --- a/redash/handlers/groups.py +++ b/redash/handlers/groups.py @@ -10,7 +10,8 @@ class GroupListResource(BaseResource): @require_admin def post(self): name = request.json['name'] - group = models.Group.create(name=name, org=self.current_org) + group = models.Group(name=name, org=self.current_org) + models.db.session.add(group) self.record_event({ 'action': 'create', @@ -40,7 +41,6 @@ def post(self, group_id): abort(400, message="Can't modify built-in groups.") group.name = request.json['name'] - group.save() self.record_event({ 'action': 'edit', @@ -52,7 +52,7 @@ def post(self, group_id): return group.to_dict() def get(self, group_id): - if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.groups): + if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.group_ids): abort(403) group = models.Group.get_by_id_and_org(group_id, self.current_org) @@ -75,8 +75,7 @@ def post(self, group_id): user_id = request.json['user_id'] user = models.User.get_by_id_and_org(user_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org) - user.groups.append(group.id) - user.save() + user.group_ids.append(group.id) self.record_event({ 'action': 'add_member', @@ -90,7 +89,7 @@ def post(self, group_id): @require_permission('list_users') def get(self, group_id): - if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.groups): + if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.group_ids): abort(403) members = models.Group.members(group_id) @@ -101,8 +100,7 @@ class GroupMemberResource(BaseResource): @require_admin def delete(self, group_id, user_id): user = models.User.get_by_id_and_org(user_id, self.current_org) - user.groups.remove(int(group_id)) - user.save() + user.group_ids.remove(int(group_id)) self.record_event({ 'action': 'remove_member', diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index a6e70ea1d1..b8855c37d6 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -32,18 +32,18 @@ class QuerySearchResource(BaseResource): def get(self): term = request.args.get('q', '') - return [q.to_dict(with_last_modified_by=False) for q in models.Query.search(term, self.current_user.groups)] + return [q.to_dict(with_last_modified_by=False) for q in models.Query.search(term, self.current_user.group_ids)] class QueryRecentResource(BaseResource): @require_permission('view_query') def get(self): - queries = models.Query.recent(self.current_user.groups, self.current_user.id) + queries = models.Query.recent(self.current_user.group_ids, self.current_user.id) recent = [d.to_dict(with_last_modified_by=False) for d in queries] global_recent = [] if len(recent) < 10: - global_recent = [d.to_dict(with_last_modified_by=False) for d in models.Query.recent(self.current_user.groups)] + global_recent = [d.to_dict(with_last_modified_by=False) for d in models.Query.recent(self.current_user.group_ids)] return take(20, distinct(chain(recent, global_recent), key=lambda d: d['id'])) @@ -81,8 +81,7 @@ def post(self): @require_permission('view_query') def get(self): - results = models.Query.all_queries([models.Group.query.get(g_id) - for g_id in self.current_user.group_ids]) + results = models.Query.all_queries(self.current_user.group_ids) page = request.args.get('page', 1, type=int) page_size = request.args.get('page_size', 25, type=int) return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False)) diff --git a/redash/models.py b/redash/models.py index 9581e4229f..542252aa7e 100644 --- a/redash/models.py +++ b/redash/models.py @@ -727,12 +727,12 @@ def create(cls, **kwargs): return query @classmethod - def all_queries(cls, groups, drafts=False): + def all_queries(cls, group_ids, drafts=False): q = (cls.query.join(User, Query.user_id == User.id) .outerjoin(QueryResult) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .filter(Query.is_archived == False) - .filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\ + .filter(DataSourceGroup.group_id.in_(group_ids))\ .group_by(Query.id, User.id, QueryResult.id, QueryResult.retrieved_at, QueryResult.runtime) .order_by(Query.created_at.desc())) @@ -745,7 +745,7 @@ def all_queries(cls, groups, drafts=False): @classmethod def by_user(cls, user, drafts): - return cls.all_queries(user.groups, drafts).filter(Query.user == user) + return cls.all_queries(user.group_ids, drafts).filter(Query.user == user) @classmethod def outdated_queries(cls): diff --git a/tests/handlers/test_alerts.py b/tests/handlers/test_alerts.py index 2f8f679a52..9d7c185f3b 100644 --- a/tests/handlers/test_alerts.py +++ b/tests/handlers/test_alerts.py @@ -63,6 +63,29 @@ def test_returns_404_for_unauthorized_users(self): self.assertEqual(rv.status_code, 404) +class TestAlertListGet(BaseTestCase): + def test_returns_all_alerts(self): + alert = self.factory.create_alert() + rv = self.make_request('get', "/api/alerts") + + self.assertEqual(rv.status_code, 200) + + alert_ids = [a['id'] for a in rv.json] + self.assertIn(alert.id, alert_ids) + + def test_returns_alerts_only_from_users_groups(self): + alert = self.factory.create_alert() + query = self.factory.create_query(data_source=self.factory.create_data_source(group=self.factory.create_group())) + alert2 = self.factory.create_alert(query_rel=query) + rv = self.make_request('get', "/api/alerts") + + self.assertEqual(rv.status_code, 200) + + alert_ids = [a['id'] for a in rv.json] + self.assertIn(alert.id, alert_ids) + self.assertNotIn(alert2.id, alert_ids) + + class TestAlertListPost(BaseTestCase): def test_returns_200_if_has_access_to_query(self): query = self.factory.create_query() diff --git a/tests/handlers/test_groups.py b/tests/handlers/test_groups.py index 5ddd44e366..9535d0d3df 100644 --- a/tests/handlers/test_groups.py +++ b/tests/handlers/test_groups.py @@ -57,6 +57,7 @@ def filtergroups(gs): self.factory.default_group, group1])) + class TestGroupResourcePost(BaseTestCase): def test_doesnt_change_builtin_groups(self): current_name = self.factory.default_group.name @@ -94,3 +95,15 @@ def test_can_delete_group_with_data_sources(self): self.assertEqual(response.status_code, 200) self.assertEqual(data_source, DataSource.query.get(data_source.id)) + + +class TestGroupResourceGet(BaseTestCase): + def test_returns_group(self): + rv = self.make_request('get', '/api/groups/{}'.format(self.factory.default_group.id)) + self.assertEqual(rv.status_code, 200) + + def test_doesnt_return_if_user_not_member_or_admin(self): + rv = self.make_request('get', '/api/groups/{}'.format(self.factory.admin_group.id)) + self.assertEqual(rv.status_code, 403) + + diff --git a/tests/models/test_queries.py b/tests/models/test_queries.py index af04e92970..2c612aa37d 100644 --- a/tests/models/test_queries.py +++ b/tests/models/test_queries.py @@ -2,8 +2,39 @@ from redash.models import Query, db -class TestApiKeyGetByObject(BaseTestCase): +class TestQueryByUser(BaseTestCase): + def test_returns_only_users_queries(self): + q = self.factory.create_query(user=self.factory.user) + q2 = self.factory.create_query(user=self.factory.create_user()) + queries = Query.by_user(self.factory.user, False) + + # not using self.assertIn/NotIn because otherwise this fails :O + self.assertTrue(q in queries) + self.assertFalse(q2 in queries) + + def test_returns_drafts_if_asked_to(self): + q = self.factory.create_query(is_draft=True) + q2 = self.factory.create_query(is_draft=False) + + queries = Query.by_user(self.factory.user, True) + + # not using self.assertIn/NotIn because otherwise this fails :O + self.assertTrue(q in queries) + self.assertFalse(q2 in queries) + + def test_returns_only_queries_from_groups_the_user_is_member_in(self): + q = self.factory.create_query() + q2 = self.factory.create_query(data_source=self.factory.create_data_source(group=self.factory.create_group())) + + queries = Query.by_user(self.factory.user, False) + + # not using self.assertIn/NotIn because otherwise this fails :O + self.assertTrue(q in queries) + self.assertFalse(q2 in queries) + + +class TestQueryFork(BaseTestCase): def assert_visualizations(self, origin_q, origin_v, forked_q, forked_v): self.assertEqual(origin_v.options, forked_v.options) self.assertEqual(origin_v.type, forked_v.type) From 1d18109964eef6347c188274bc7060513421ea6b Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 16:07:25 +0200 Subject: [PATCH 76/80] Fix tests that used Query.all_queries --- redash/handlers/widgets.py | 1 - redash/serializers.py | 2 +- tests/test_models.py | 12 ++++++------ 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/redash/handlers/widgets.py b/redash/handlers/widgets.py index 47ce325705..06d334a81b 100644 --- a/redash/handlers/widgets.py +++ b/redash/handlers/widgets.py @@ -61,7 +61,6 @@ def post(self, widget_id): require_object_modify_permission(widget.dashboard, self.current_user) widget_properties = request.get_json(force=True) widget.text = widget_properties['text'] - widget.save() return widget.to_dict() diff --git a/redash/serializers.py b/redash/serializers.py index 769a770eed..a9f3862cd6 100644 --- a/redash/serializers.py +++ b/redash/serializers.py @@ -19,7 +19,7 @@ def public_widget(widget): } if widget.visualization and widget.visualization.id: - query_data = models.QueryResult.get_by_id(widget.visualization.query.latest_query_data_id).to_dict() + query_data = models.QueryResult.query.get(widget.visualization.query.latest_query_data_id).to_dict() res['visualization'] = { 'type': widget.visualization.type, 'name': widget.visualization.name, diff --git a/tests/test_models.py b/tests/test_models.py index 0e6fdc5c16..a0a033ba30 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -262,12 +262,12 @@ def test_archived_query_doesnt_return_in_all(self): query.latest_query_data = query_result groups = list(models.Group.query.filter(models.Group.id.in_(query.groups))) - self.assertIn(query, list(models.Query.all_queries(groups))) + self.assertIn(query, list(models.Query.all_queries([g.id for g in groups]))) self.assertIn(query, models.Query.outdated_queries()) db.session.flush() query.archive() - self.assertNotIn(query, list(models.Query.all_queries(groups))) + self.assertNotIn(query, list(models.Query.all_queries([g.id for g in groups]))) self.assertNotIn(query, models.Query.outdated_queries()) def test_removes_associated_widgets_from_dashboards(self): @@ -428,10 +428,10 @@ def test_returns_only_queries_in_given_groups(self): models.DataSourceGroup(group=group2, data_source=ds2) ]) db.session.flush() - self.assertIn(q1, list(models.Query.all_queries([group1]))) - self.assertNotIn(q2, list(models.Query.all_queries([group1]))) - self.assertIn(q1, list(models.Query.all_queries([group1, group2]))) - self.assertIn(q2, list(models.Query.all_queries([group1, group2]))) + self.assertIn(q1, list(models.Query.all_queries([group1.id]))) + self.assertNotIn(q2, list(models.Query.all_queries([group1.id]))) + self.assertIn(q1, list(models.Query.all_queries([group1.id, group2.id]))) + self.assertIn(q2, list(models.Query.all_queries([group1.id, group2.id]))) class TestGroup(BaseTestCase): From 81fb139b88add47d4d41d2b4be36d76c4d72332d Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Thu, 8 Dec 2016 16:16:50 +0200 Subject: [PATCH 77/80] Add shell_context_processor to inject models to shell. --- redash/cli/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/redash/cli/__init__.py b/redash/cli/__init__.py index caaf37e895..7fb550eb31 100644 --- a/redash/cli/__init__.py +++ b/redash/cli/__init__.py @@ -13,6 +13,12 @@ def create(group): app = current_app or create_app() group.app = app + + @app.shell_context_processor + def shell_context(): + from redash import models + return dict(models=models) + return app From e524db0215f70d4dab0a1f998fa9852ce217d690 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Fri, 9 Dec 2016 14:40:00 -0600 Subject: [PATCH 78/80] Measure query time with statsd --- redash/metrics/database.py | 124 ++++++++++--------------------------- redash/models.py | 1 + 2 files changed, 32 insertions(+), 93 deletions(-) diff --git a/redash/metrics/database.py b/redash/metrics/database.py index a4ce6e312f..4cfb1ad83e 100644 --- a/redash/metrics/database.py +++ b/redash/metrics/database.py @@ -1,99 +1,37 @@ -from functools import wraps import time import logging -from playhouse.gfk import Model -from playhouse.postgres_ext import PostgresqlExtDatabase -from redash import statsd_client - -metrics_logger = logging.getLogger("metrics") - - -class MeteredPostgresqlExtDatabase(PostgresqlExtDatabase): - def __init__(self, *args, **kwargs): - self.query_count = 0 - self.query_duration = 0 - return super(MeteredPostgresqlExtDatabase, self).__init__(*args, **kwargs) - - def execute_sql(self, *args, **kwargs): - start_time = time.time() - - try: - result = super(MeteredPostgresqlExtDatabase, self).execute_sql(*args, **kwargs) - return result - finally: - self.query_count += 1 - # TODO: there is a noticeable few ms discrepancy between the duration here and the one calculated in - # metered_execute. Need to think what to do about it. - duration = (time.time() - start_time) * 1000 - self.query_duration += duration - - def reset_metrics(self): - # TODO: instead of manually managing reset of metrics, we should store them in a LocalProxy based object, that - # is guaranteed to be "replaced" when the current request is done. - self.query_count = 0 - self.query_duration = 0 - - -def patch_query_execute(): - real_execute = peewee.Query._execute - real_clone = peewee.Query.clone - - @wraps(real_execute) - def metered_execute(self, *args, **kwargs): - name = self.model_class.__name__ - action = getattr(self, 'model_action', 'unknown') +from sqlalchemy.engine import Engine - start_time = time.time() - try: - result = real_execute(self, *args, **kwargs) - return result - finally: - duration = (time.time() - start_time) * 1000 - statsd_client.timing('db.{}.{}'.format(name, action), duration) - metrics_logger.debug("model=%s query=%s duration=%.2f", name, action, duration) - - @wraps(real_clone) - def extended_clone(self): - cloned = real_clone(self) - setattr(cloned, 'model_action', getattr(self, 'model_action', 'unknown')) - return cloned - - peewee.Query._execute = metered_execute - peewee.Query.clone = extended_clone - - -class MeteredModel(Model): - @classmethod - def select(cls, *args, **kwargs): - return cls._execute_and_measure('select', args, kwargs) - - @classmethod - def update(cls, *args, **kwargs): - return cls._execute_and_measure('update', args, kwargs) - - @classmethod - def insert(cls, *args, **kwargs): - return cls._execute_and_measure('insert', args, kwargs) - - @classmethod - def insert_many(cls, *args, **kwargs): - return cls._execute_and_measure('insert_many', args, kwargs) - - @classmethod - def insert_from(cls, *args, **kwargs): - return cls._execute_and_measure('insert_from', args, kwargs) - - @classmethod - def delete(cls, *args, **kwargs): - return cls._execute_and_measure('delete', args, kwargs) +from redash import statsd_client - @classmethod - def raw(cls, *args, **kwargs): - return cls._execute_and_measure('raw', args, kwargs) +metrics_logger = logging.getLogger("metrics") - @classmethod - def _execute_and_measure(cls, action, args, kwargs): - result = getattr(super(MeteredModel, cls), action)(*args, **kwargs) - setattr(result, 'model_action', action) - return result +from sqlalchemy.orm.util import _ORMJoin +from sqlalchemy.event import listens_for + + +@listens_for(Engine, "before_execute") +def before_execute(conn, elt, multiparams, params): + conn.info.setdefault('query_start_time', []).append(time.time()) + + +@listens_for(Engine, "after_execute") +def after_execute(conn, elt, multiparams, params, result): + duration = time.time() - conn.info['query_start_time'].pop(-1) + action = elt.__class__.__name__ + + if action == 'Select': + t = elt.froms[0] + while isinstance(t, _ORMJoin): + t = t.left + name = t.name + elif action in ['Update', 'Insert', 'Delete']: + name = elt.table.name + else: + # create/drop tables, sqlalchemy internal schema queries, etc + return + statsd_client.timing('db.{}.{}'.format(name, action), duration) + metrics_logger.debug("model=%s query=%s duration=%.2f", name, action, + duration) + return result diff --git a/redash/models.py b/redash/models.py index 542252aa7e..84f8a3e773 100644 --- a/redash/models.py +++ b/redash/models.py @@ -25,6 +25,7 @@ from redash.query_runner import get_query_runner, get_configuration_schema_for_query_runner_type from redash.utils import generate_token, json_dumps from redash.utils.configuration import ConfigurationContainer +from redash.metrics import database db = SQLAlchemy() Column = functools.partial(db.Column, nullable=False) From 1978e077487a75c9801cc113503cabac6499bb0d Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 11 Dec 2016 15:11:30 +0200 Subject: [PATCH 79/80] Use group ids instead of groups in Queries.search/recent --- redash/models.py | 8 ++++---- tests/test_models.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/redash/models.py b/redash/models.py index 84f8a3e773..701f4f7cf2 100644 --- a/redash/models.py +++ b/redash/models.py @@ -765,7 +765,7 @@ def outdated_queries(cls): return outdated_queries.values() @classmethod - def search(cls, term, groups): + def search(cls, term, group_ids): # TODO: This is very naive implementation of search, to be replaced with PostgreSQL full-text-search solution. where = (Query.name.like(u"%{}%".format(term)) | Query.description.like(u"%{}%".format(term))) @@ -774,7 +774,7 @@ def search(cls, term, groups): where |= Query.id == term where &= Query.is_archived == False - where &= DataSourceGroup.group_id.in_([g.id for g in groups]) + where &= DataSourceGroup.group_id.in_(group_ids) query_ids = ( db.session.query(Query.id).join( DataSourceGroup, @@ -785,7 +785,7 @@ def search(cls, term, groups): Query.id.in_(query_ids)) @classmethod - def recent(cls, groups, user_id=None, limit=20): + def recent(cls, group_ids, user_id=None, limit=20): query = (cls.query.join(User, Query.user_id == User.id) .filter(Event.created_at > (db.func.current_date() - 7)) .join(Event, Query.id == Event.object_id.cast(db.Integer)) @@ -795,7 +795,7 @@ def recent(cls, groups, user_id=None, limit=20): 'edit_description', 'view_source']), Event.object_id != None, Event.object_type == 'query', - DataSourceGroup.group_id.in_([g.id for g in groups]), + DataSourceGroup.group_id.in_(group_ids), Query.is_draft == False, Query.is_archived == False) .group_by(Event.object_id, Query.id, User.id) diff --git a/tests/test_models.py b/tests/test_models.py index a0a033ba30..482e28b4f5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -39,7 +39,7 @@ def test_search_finds_in_name(self): q1 = self.factory.create_query(name=u"Testing seåřċħ") q2 = self.factory.create_query(name=u"Testing seåřċħing") q3 = self.factory.create_query(name=u"Testing seå řċħ") - queries = list(models.Query.search(u"seåřċħ", [self.factory.default_group])) + queries = list(models.Query.search(u"seåřċħ", [self.factory.default_group.id])) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -50,7 +50,7 @@ def test_search_finds_in_description(self): q2 = self.factory.create_query(description=u"Testing seåřċħing") q3 = self.factory.create_query(description=u"Testing seå řċħ") - queries = models.Query.search(u"seåřċħ", [self.factory.default_group]) + queries = models.Query.search(u"seåřċħ", [self.factory.default_group.id]) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -61,7 +61,7 @@ def test_search_by_id_returns_query(self): q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") db.session.flush() - queries = models.Query.search(str(q3.id), [self.factory.default_group]) + queries = models.Query.search(str(q3.id), [self.factory.default_group.id]) self.assertIn(q3, queries) self.assertNotIn(q1, queries) @@ -76,18 +76,18 @@ def test_search_respects_groups(self): q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") - queries = list(models.Query.search("Testing", [self.factory.default_group])) + queries = list(models.Query.search("Testing", [self.factory.default_group.id])) self.assertNotIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = list(models.Query.search("Testing", [other_group, self.factory.default_group])) + queries = list(models.Query.search("Testing", [other_group.id, self.factory.default_group.id])) self.assertIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = list(models.Query.search("Testing", [other_group])) + queries = list(models.Query.search("Testing", [other_group.id])) self.assertIn(q1, queries) self.assertNotIn(q2, queries) self.assertNotIn(q3, queries) @@ -100,7 +100,7 @@ def test_returns_each_query_only_once(self): q1 = self.factory.create_query(description="Testing search", data_source=ds) db.session.flush() - queries = list(models.Query.search("Testing", [self.factory.default_group, other_group, second_group])) + queries = list(models.Query.search("Testing", [self.factory.default_group.id, other_group.id, second_group.id])) self.assertEqual(1, len(queries)) From b3bfc3bc74f89d98759b36d20702f140a98e884b Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 11 Dec 2016 15:41:07 +0200 Subject: [PATCH 80/80] Bring back support for total query time/queries count --- redash/metrics/database.py | 21 ++++++++++++++++----- redash/metrics/request.py | 14 +++++++------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/redash/metrics/database.py b/redash/metrics/database.py index 4cfb1ad83e..405873a8c4 100644 --- a/redash/metrics/database.py +++ b/redash/metrics/database.py @@ -2,14 +2,15 @@ import logging from sqlalchemy.engine import Engine +from sqlalchemy.orm.util import _ORMJoin +from sqlalchemy.event import listens_for + +from flask import has_request_context, g from redash import statsd_client metrics_logger = logging.getLogger("metrics") -from sqlalchemy.orm.util import _ORMJoin -from sqlalchemy.event import listens_for - @listens_for(Engine, "before_execute") def before_execute(conn, elt, multiparams, params): @@ -18,7 +19,7 @@ def before_execute(conn, elt, multiparams, params): @listens_for(Engine, "after_execute") def after_execute(conn, elt, multiparams, params, result): - duration = time.time() - conn.info['query_start_time'].pop(-1) + duration = 1000 * (time.time() - conn.info['query_start_time'].pop(-1)) action = elt.__class__.__name__ if action == 'Select': @@ -31,7 +32,17 @@ def after_execute(conn, elt, multiparams, params, result): else: # create/drop tables, sqlalchemy internal schema queries, etc return + + action = action.lower() + statsd_client.timing('db.{}.{}'.format(name, action), duration) - metrics_logger.debug("model=%s query=%s duration=%.2f", name, action, + metrics_logger.debug("table=%s query=%s duration=%.2f", name, action, duration) + + if has_request_context(): + g.setdefault('queries_count', 0) + g.setdefault('queries_duration', 0) + g.queries_count += 1 + g.queries_duration += duration + return result diff --git a/redash/metrics/request.py b/redash/metrics/request.py index b6ceada970..aa667c35b6 100644 --- a/redash/metrics/request.py +++ b/redash/metrics/request.py @@ -4,7 +4,6 @@ from flask import g, request from redash import statsd_client -from redash.models import db metrics_logger = logging.getLogger("metrics") @@ -18,6 +17,8 @@ def calculate_metrics(response): return response request_duration = (time.time() - g.start_time) * 1000 + queries_duration = g.get('queries_duration', 0.0) + queries_count = g.get('queries_count', 0.0) metrics_logger.info("method=%s path=%s endpoint=%s status=%d content_type=%s content_length=%d duration=%.2f query_count=%d query_duration=%.2f", request.method, @@ -27,8 +28,8 @@ def calculate_metrics(response): response.content_type, response.content_length, request_duration, - # XXX instrument SQLA for metrics - None, None) + queries_count, + queries_duration) statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration) @@ -43,7 +44,6 @@ def calculate_metrics_on_exception(error): def provision_app(app): - # app.before_request(record_requets_start_time) - # app.after_request(calculate_metrics) - # app.teardown_request(calculate_metrics_on_exception) - pass + app.before_request(record_requets_start_time) + app.after_request(calculate_metrics) + app.teardown_request(calculate_metrics_on_exception)