diff --git a/redash/__init__.py b/redash/__init__.py index dabbd39ee6..3b0a24cae6 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -7,6 +7,8 @@ from werkzeug.routing import BaseConverter, ValidationError from statsd import StatsClient from flask_mail import Mail +from flask_limiter import Limiter +from flask_limiter.util import get_ipaddr from redash import settings from redash.query_runner import import_query_runners @@ -52,6 +54,7 @@ def create_redis_connection(): mail = Mail() mail.init_mail(settings.all_settings()) statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX) +limiter = Limiter(key_func=get_ipaddr, storage_uri=settings.REDIS_URL) import_query_runners(settings.QUERY_RUNNERS) import_destinations(settings.DESTINATIONS) @@ -112,5 +115,6 @@ def create_app(): mail.init_app(app) setup_authentication(app) handlers.init_app(app) + limiter.init_app(app) return app diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index e7e98cf7ac..0600dc2adc 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -2,7 +2,7 @@ from flask import render_template, request, redirect, url_for, flash from flask_login import current_user, login_user, logout_user -from redash import models, settings +from redash import models, settings, limiter from redash.handlers import routes from redash.handlers.base import org_scoped_rule from redash.authentication import current_org, get_login_url @@ -81,6 +81,7 @@ def forgot_password(org_slug=None): @routes.route(org_scoped_rule('/login'), methods=['GET', 'POST']) +@limiter.limit(settings.THROTTLE_LOGIN_PATTERN) def login(org_slug=None): index_url = url_for("redash.index", org_slug=org_slug) next_path = request.args.get('next', index_url) diff --git a/redash/settings.py b/redash/settings.py index f597c9495a..35b959c384 100644 --- a/redash/settings.py +++ b/redash/settings.py @@ -146,6 +146,11 @@ def all_settings(): ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE = os.environ.get('REDASH_ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE', "({state}) {alert_name}") +# How many requests are allowed per IP to the login page before +# being throttled? +# See https://flask-limiter.readthedocs.io/en/stable/#rate-limit-string-notation +THROTTLE_LOGIN_PATTERN = os.environ.get('REDASH_THROTTLE_LOGIN_PATTERN', '50/hour') + # CORS settings for the Query Result API (and possbily future external APIs). # In most cases all you need to do is set REDASH_CORS_ACCESS_CONTROL_ALLOW_ORIGIN # to the calling domain (or domains in a comma separated list). diff --git a/requirements.txt b/requirements.txt index 2115763bea..2ccc3df6ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,3 +40,4 @@ xlsxwriter==0.9.3 pystache==0.5.4 parsedatetime==2.1 cryptography==1.4 +Flask-Limiter==0.9.3 diff --git a/tests/handlers/test_authentication.py b/tests/handlers/test_authentication.py index 132c8d40d3..8896aee468 100644 --- a/tests/handlers/test_authentication.py +++ b/tests/handlers/test_authentication.py @@ -1,6 +1,7 @@ from tests import BaseTestCase import mock import time +from redash import settings from redash.models import User from redash.authentication.account import invite_token from tests.handlers import get_request, post_request @@ -55,3 +56,13 @@ def test_valid_password(self): user = User.get_by_id(self.factory.user.id) self.assertTrue(user.verify_password(password)) + +class TestLogin(BaseTestCase): + def test_throttle_login(self): + # Extract the limit from settings (ex: '50/day') + limit = settings.THROTTLE_LOGIN_PATTERN.split('/')[0] + for _ in range(0, int(limit)): + get_request('/login', org=self.factory.org) + + response = get_request('/login', org=self.factory.org) + self.assertEqual(response.status_code, 429)