diff --git a/explorer/exporters.py b/explorer/exporters.py index 6f560cc2..0450a291 100644 --- a/explorer/exporters.py +++ b/explorer/exporters.py @@ -1,3 +1,4 @@ +import codecs import csv import json import string @@ -61,6 +62,7 @@ def _get_output(self, res, **kwargs): delim = '\t' if delim == 'tab' else str(delim) delim = app_settings.CSV_DELIMETER if len(delim) > 1 else delim csv_data = StringIO() + csv_data.write(codecs.BOM_UTF8.decode('utf-8')) writer = csv.writer(csv_data, delimiter=delim) writer.writerow(res.headers) for row in res.data: diff --git a/explorer/tests/test_actions.py b/explorer/tests/test_actions.py index fd4d91cc..6fd40b76 100644 --- a/explorer/tests/test_actions.py +++ b/explorer/tests/test_actions.py @@ -11,12 +11,12 @@ class TestSqlQueryActions(TestCase): def test_single_query_is_csv_file(self): - expected_csv = b'two\r\n2\r\n' + expected_csv = 'two\r\n2\r\n' r = SimpleQueryFactory() fn = generate_report_action() result = fn(None, None, [r, ]) - self.assertEqual(result.content.lower(), expected_csv) + self.assertEqual(result.content.lower().decode('utf-8-sig'), expected_csv) def test_multiple_queries_are_zip_file(self): @@ -32,7 +32,7 @@ def test_multiple_queries_are_zip_file(self): self.assertEqual(len(z.namelist()), 2) self.assertEqual(z.namelist()[0], f'{q.title}.csv') - self.assertEqual(got_csv.lower().decode('utf-8'), expected_csv) + self.assertEqual(got_csv.lower().decode('utf-8-sig'), expected_csv) # if commas are not removed from the filename, then Chrome throws # "duplicate headers received from server" diff --git a/explorer/tests/test_exporters.py b/explorer/tests/test_exporters.py index 1d89bd0a..a3203491 100644 --- a/explorer/tests/test_exporters.py +++ b/explorer/tests/test_exporters.py @@ -25,13 +25,25 @@ def test_writing_unicode(self): res._data = [[1, None], ["Jenét", '1']] res = CSVExporter(query=None)._get_output(res).getvalue() - self.assertEqual(res, 'a,\r\n1,\r\nJenét,1\r\n') + self.assertEqual( + res.encode('utf-8').decode('utf-8-sig'), + 'a,\r\n1,\r\nJenét,1\r\n' + ) def test_custom_delimiter(self): q = SimpleQueryFactory(sql='select 1, 2') exporter = CSVExporter(query=q) res = exporter.get_output(delim='|') - self.assertEqual(res, '1|2\r\n1|2\r\n') + self.assertEqual( + res.encode('utf-8').decode('utf-8-sig'), + '1|2\r\n1|2\r\n' + ) + + def test_writing_bom(self): + q = SimpleQueryFactory(sql='select 1, 2') + exporter = CSVExporter(query=q) + res = exporter.get_output() + self.assertEqual(res, '\ufeff1,2\r\n1,2\r\n') class TestJson(TestCase): diff --git a/explorer/tests/test_tasks.py b/explorer/tests/test_tasks.py index 6a89ba5d..a6df715c 100644 --- a/explorer/tests/test_tasks.py +++ b/explorer/tests/test_tasks.py @@ -35,7 +35,10 @@ def test_async_results(self, mocked_upload): ) self.assertIn('[SQL Explorer] Report ', mail.outbox[1].subject) self.assertEqual( - mocked_upload.call_args[0][1].getvalue(), output.getvalue() + mocked_upload + .call_args[0][1].getvalue() + .encode('utf-8').decode('utf-8-sig'), + output.getvalue() ) self.assertEqual(mocked_upload.call_count, 1) diff --git a/explorer/tests/test_utils.py b/explorer/tests/test_utils.py index 279ea601..9ff06ebf 100644 --- a/explorer/tests/test_utils.py +++ b/explorer/tests/test_utils.py @@ -28,13 +28,13 @@ def test_overriding_blacklist(self): r = SimpleQueryFactory(sql="SELECT 1+1 AS \"DELETE\";") fn = generate_report_action() result = fn(None, None, [r, ]) - self.assertEqual(result.content, b'DELETE\r\n2\r\n') + self.assertEqual(result.content.decode('utf-8-sig'), 'DELETE\r\n2\r\n') def test_default_blacklist_prevents_deletes(self): r = SimpleQueryFactory(sql="SELECT 1+1 AS \"DELETE\";") fn = generate_report_action() result = fn(None, None, [r, ]) - self.assertEqual(result.content.decode('utf-8'), '0') + self.assertEqual(result.content.decode('utf-8-sig'), '0') def test_queries_deleting_stuff_are_not_ok(self): sql = "'distraction'; deLeTe from table; " \ diff --git a/explorer/tests/test_views.py b/explorer/tests/test_views.py index df3fe7dd..3141602a 100644 --- a/explorer/tests/test_views.py +++ b/explorer/tests/test_views.py @@ -537,7 +537,7 @@ def test_sql_download_csv_with_custom_delim(self): self.assertEqual(response.status_code, 200) self.assertEqual(response['content-type'], 'text/csv') - self.assertEqual(response.content.decode('utf-8'), '1|2\r\n1|2\r\n') + self.assertEqual(response.content.decode('utf-8-sig'), '1|2\r\n1|2\r\n') def test_sql_download_csv_with_tab_delim(self): url = reverse("download_sql") + '?format=csv&delim=tab' @@ -546,7 +546,7 @@ def test_sql_download_csv_with_tab_delim(self): self.assertEqual(response.status_code, 200) self.assertEqual(response['content-type'], 'text/csv') - self.assertEqual(response.content.decode('utf-8'), '1\t2\r\n1\t2\r\n') + self.assertEqual(response.content.decode('utf-8-sig'), '1\t2\r\n1\t2\r\n') def test_sql_download_csv_with_bad_delim(self): url = reverse("download_sql") + '?format=csv&delim=foo' @@ -555,7 +555,7 @@ def test_sql_download_csv_with_bad_delim(self): self.assertEqual(response.status_code, 200) self.assertEqual(response['content-type'], 'text/csv') - self.assertEqual(response.content.decode('utf-8'), '1,2\r\n1,2\r\n') + self.assertEqual(response.content.decode('utf-8-sig'), '1,2\r\n1,2\r\n') def test_sql_download_json(self): url = reverse("download_sql") + '?format=json' diff --git a/test_project/start.sh b/test_project/start.sh index 17b07bd3..be58c9f1 100644 --- a/test_project/start.sh +++ b/test_project/start.sh @@ -1,5 +1,5 @@ -pip install -r requirements.txt -pip install -r optional-requirements.txt +pip install -r requirements/base.txt +pip install -r requirements/optional.txt python manage.py migrate python manage.py shell <