Skip to content

Commit

Permalink
Update boto and celery integration (#501)
Browse files Browse the repository at this point in the history
  • Loading branch information
lawson89 authored Dec 4, 2022
1 parent a04764f commit 0d43898
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 55 deletions.
2 changes: 1 addition & 1 deletion docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Snapshots
'schedule': crontab(hour=1, minute=0)
}
- Requires celery, obviously. Also uses djcelery and tinys3. All
- Requires celery, obviously. Also uses boto3. All
of these deps are optional and can be installed with
``pip install -r optional-requirements.txt``
- The checkbox for opting a query into a snapshot is ALL THE WAY
Expand Down
11 changes: 11 additions & 0 deletions docs/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ S3 Bucket for snapshot upload
EXPLORER_S3_BUCKET = None
S3 link expiration
******************

S3 link expiration time. Defaults to 3600 seconds (1 hour) if not specified.
For security, links are generated as presigned URLs.

.. code-block:: python
EXPLORER_S3_S3_LINK_EXPIRATION = 3600
From email
**********

Expand Down
1 change: 1 addition & 0 deletions explorer/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
S3_ACCESS_KEY = getattr(settings, "EXPLORER_S3_ACCESS_KEY", None)
S3_SECRET_KEY = getattr(settings, "EXPLORER_S3_SECRET_KEY", None)
S3_BUCKET = getattr(settings, "EXPLORER_S3_BUCKET", None)
S3_LINK_EXPIRATION: int = getattr(settings, "EXPLORER_S3_S3_LINK_EXPIRATION", 3600)
FROM_EMAIL = getattr(
settings, 'EXPLORER_FROM_EMAIL', 'django-sql-explorer@example.com'
)
Expand Down
14 changes: 7 additions & 7 deletions explorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from explorer import app_settings
from explorer.utils import (
extract_params, get_params_for_url, get_s3_bucket, get_valid_connection, passes_blacklist, shared_dict_update,
swap_params,
extract_params, get_params_for_url, get_s3_bucket, get_valid_connection, passes_blacklist, s3_url,
shared_dict_update, swap_params,
)


Expand Down Expand Up @@ -133,13 +133,13 @@ def shared(self):
def snapshots(self):
if app_settings.ENABLE_TASKS:
b = get_s3_bucket()
keys = b.list(prefix=f'query-{self.id}/snap-')
keys_s = sorted(keys, key=lambda k: k.last_modified)
objects = b.objects.filter(Prefix=f'query-{self.id}/snap-')
objects_s = sorted(objects, key=lambda k: k.last_modified)
return [
SnapShot(
k.generate_url(expires_in=0, query_auth=False),
k.last_modified
) for k in keys_s
s3_url(b, o.key),
o.last_modified
) for o in objects_s
]


Expand Down
33 changes: 21 additions & 12 deletions explorer/tasks.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import io
import random
import string
from datetime import date, datetime, timedelta

from django.core.cache import cache
from django.core.mail import send_mail
from django.db import DatabaseError

from explorer import app_settings
from explorer.exporters import get_exporter_class
from explorer.models import Query, QueryLog


if app_settings.ENABLE_TASKS:
from celery import task
from celery import shared_task
from celery.utils.log import get_task_logger

from explorer.utils import s3_upload
logger = get_task_logger(__name__)
else:
import logging

from explorer.utils import noop_decorator as task
from explorer.utils import noop_decorator as shared_task
logger = logging.getLogger(__name__)


@task
@shared_task
def execute_query(query_id, email_address):
q = Query.objects.get(pk=query_id)
send_mail('[SQL Explorer] Your query is running...',
Expand All @@ -39,17 +39,26 @@ def execute_query(query_id, email_address):
) for _ in range(20)
)
try:
url = s3_upload(f'{random_part}.csv', exporter.get_file_output())
url = s3_upload(f'{random_part}.csv', convert_csv_to_bytesio(exporter))
subj = f'[SQL Explorer] Report "{q.title}" is ready'
msg = f'Download results:\n\r{url}'
except DatabaseError as e:
except Exception as e:
subj = f'[SQL Explorer] Error running report {q.title}'
msg = f'Error: {e}\nPlease contact an administrator'
logger.warning(f'{subj}: {e}')
logger.exception(f'{subj}: {e}')
send_mail(subj, msg, app_settings.FROM_EMAIL, [email_address])


@task
def convert_csv_to_bytesio(csv_exporter):
    """Return the exporter's CSV output as a binary file-like object.

    boto3's ``upload_fileobj`` expects a file opened in binary mode, but
    the exporter produces a text stream, so the text is re-encoded as
    UTF-8 bytes wrapped in a ``BytesIO``.
    """
    text_stream = csv_exporter.get_file_output()
    # Rewind first: the exporter may leave the cursor at the end.
    text_stream.seek(0)
    return io.BytesIO(text_stream.read().encode('utf-8'))


@shared_task
def snapshot_query(query_id):
try:
logger.info(f"Starting snapshot for query {query_id}...")
Expand All @@ -60,7 +69,7 @@ def snapshot_query(query_id):
date.today().strftime('%Y%m%d-%H:%M:%S')
)
logger.info(f"Uploading snapshot for query {query_id} as {k}...")
url = s3_upload(k, exporter.get_file_output())
url = s3_upload(k, convert_csv_to_bytesio(exporter))
logger.info(
f"Done uploading snapshot for query {query_id}. URL: {url}"
)
Expand All @@ -71,7 +80,7 @@ def snapshot_query(query_id):
snapshot_query.retry()


@task
@shared_task
def snapshot_queries():
logger.info("Starting query snapshots...")
qs = Query.objects.filter(snapshot=True).values_list('id', flat=True)
Expand All @@ -83,7 +92,7 @@ def snapshot_queries():
logger.info("Done creating tasks.")


@task
@shared_task
def truncate_querylogs(days):
qs = QueryLog.objects.filter(
run_at__lt=datetime.now() - timedelta(days=days)
Expand All @@ -95,7 +104,7 @@ def truncate_querylogs(days):
logger.info('Done deleting QueryLog objects.')


@task
@shared_task
def build_schema_cache_async(connection_alias):
from .schema import build_schema_info, connection_schema_cache_key
ret = build_schema_info(connection_alias)
Expand Down
22 changes: 12 additions & 10 deletions explorer/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,25 @@ def test_log_saves_duration(self):
log = QueryLog.objects.first()
self.assertEqual(log.duration, res.duration)

@patch('explorer.models.s3_url')
@patch('explorer.models.get_s3_bucket')
def test_get_snapshots_sorts_snaps(self, mocked_conn):
conn = Mock()
conn.list = Mock()
def test_get_snapshots_sorts_snaps(self, mocked_get_s3_bucket, mocked_s3_url):
bucket = Mock()
bucket.objects.filter = Mock()
k1 = Mock()
k1.generate_url.return_value = 'http://s3.com/foo'
k1.key = 'foo'
k1.last_modified = 'b'
k2 = Mock()
k2.generate_url.return_value = 'http://s3.com/bar'
k2.key = 'bar'
k2.last_modified = 'a'
conn.list.return_value = [k1, k2]
mocked_conn.return_value = conn
bucket.objects.filter.return_value = [k1, k2]
mocked_get_s3_bucket.return_value = bucket
mocked_s3_url.return_value = 'http://s3.com/presigned_url'
q = SimpleQueryFactory()
snaps = q.snapshots
self.assertEqual(conn.list.call_count, 1)
self.assertEqual(snaps[0].url, 'http://s3.com/bar')
conn.list.assert_called_once_with(prefix=f'query-{q.id}/snap-')
self.assertEqual(bucket.objects.filter.call_count, 1)
self.assertEqual(snaps[0].url, 'http://s3.com/presigned_url')
bucket.objects.filter.assert_called_once_with(Prefix=f'query-{q.id}/snap-')

def test_final_sql_uses_merged_params(self):
q = SimpleQueryFactory(sql="select '$$foo:bar$$', '$$qux$$';")
Expand Down
5 changes: 4 additions & 1 deletion explorer/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from io import StringIO
from unittest.mock import patch
Expand Down Expand Up @@ -31,7 +32,9 @@ def test_async_results(self, mocked_upload):
)
self.assertIn('[SQL Explorer] Report ', mail.outbox[1].subject)
self.assertEqual(
mocked_upload.call_args[0][1].getvalue().encode('utf-8').decode('utf-8-sig'),
mocked_upload
.call_args[0][1].getvalue()
.decode('utf-8-sig'),
output.getvalue()
)
self.assertEqual(mocked_upload.call_count, 1)
Expand Down
4 changes: 2 additions & 2 deletions explorer/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,14 @@ def test_user_query_views(self):
@patch('explorer.models.get_s3_bucket')
def test_query_snapshot_renders(self, mocked_conn):
conn = Mock()
conn.list = Mock()
conn.objects.filter = Mock()
k1 = Mock()
k1.generate_url.return_value = 'http://s3.com/foo'
k1.last_modified = '2015-01-01'
k2 = Mock()
k2.generate_url.return_value = 'http://s3.com/bar'
k2.last_modified = '2015-01-02'
conn.list.return_value = [k1, k2]
conn.objects.filter.return_value = [k1, k2]
mocked_conn.return_value = conn

query = SimpleQueryFactory(sql="select 1;", snapshot=True)
Expand Down
27 changes: 15 additions & 12 deletions explorer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.contrib.auth.forms import AuthenticationForm
from django.contrib.auth.views import LoginView

import boto3
import sqlparse
from sqlparse import format as sql_format
from sqlparse.sql import Token, TokenList
Expand Down Expand Up @@ -186,19 +187,21 @@ def get_valid_connection(alias=None):


def get_s3_bucket():
from boto.s3.connection import S3Connection

conn = S3Connection(app_settings.S3_ACCESS_KEY,
app_settings.S3_SECRET_KEY)
return conn.get_bucket(app_settings.S3_BUCKET)
s3 = boto3.resource('s3',
aws_access_key_id=app_settings.S3_ACCESS_KEY,
aws_secret_access_key=app_settings.S3_SECRET_KEY)
return s3.Bucket(name=app_settings.S3_BUCKET)


def s3_upload(key, data):
from boto.s3.key import Key
bucket = get_s3_bucket()
k = Key(bucket)
k.key = key
k.set_contents_from_file(data, rewind=True)
k.set_acl('public-read')
k.set_metadata('Content-Type', 'text/csv')
return k.generate_url(expires_in=0, query_auth=False)
bucket.upload_fileobj(data, key, ExtraArgs={'ContentType': "text/csv"})
return s3_url(bucket, key)


def s3_url(bucket, key):
    """Return a time-limited presigned GET URL for *key*.

    The link expires after ``app_settings.S3_LINK_EXPIRATION`` seconds,
    so snapshot/report downloads are not publicly accessible forever.
    """
    client = bucket.meta.client
    return client.generate_presigned_url(
        'get_object',
        Params={'Bucket': app_settings.S3_BUCKET, 'Key': key},
        ExpiresIn=app_settings.S3_LINK_EXPIRATION,
    )
6 changes: 3 additions & 3 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
celery>=3.1,<4.0
boto>=2.49
django-celery>=3.3.1
importlib-metadata<5.0; python_version <= '3.7'
celery>=4.0
boto3>=1.20.0
xlsxwriter>=1.3.6
factory-boy>=3.1.0
matplotlib<4
Expand Down
3 changes: 3 additions & 0 deletions test_project/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .celery_config import app as celery_app

__all__ = ['celery_app']
17 changes: 17 additions & 0 deletions test_project/celery_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os

from celery import Celery

# Set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings')

app = Celery('test_project')

# Using a string here means the worker doesn't have to serialize
# the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix.
app.config_from_object('django.conf:settings', namespace='CELERY')

# Load task modules from all registered Django apps.
app.autodiscover_tasks()
10 changes: 3 additions & 7 deletions test_project/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os

import djcelery


SECRET_KEY = 'shhh'
DEBUG = True
Expand Down Expand Up @@ -70,7 +68,6 @@
'django.contrib.staticfiles',
'django.contrib.admin',
'explorer',
'djcelery'
)

AUTHENTICATION_BACKENDS = (
Expand All @@ -85,11 +82,10 @@
'django.contrib.messages.middleware.MessageMiddleware',
]

TEST_RUNNER = 'djcelery.contrib.test_runner.CeleryTestSuiteRunner'
CELERY_TASK_ALWAYS_EAGER = True

djcelery.setup_loader()
CELERY_ALWAYS_EAGER = True
BROKER_BACKEND = 'memory'
# added to help debug tasks
EMAIL_BACKEND = 'django.core.mail.backends.console.EmailBackend'

# Explorer-specific

Expand Down

0 comments on commit 0d43898

Please sign in to comment.