Skip to content

Commit

Permalink
Nets Feature: SQLLab Detokenisation and Refactor Backend Detokenisation (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
RossMoir authored Sep 12, 2022
1 parent 42e0b44 commit 464dcef
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 30 deletions.
5 changes: 4 additions & 1 deletion superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,10 @@ export function runQuery(query) {
return SupersetClient.post({
endpoint: `/superset/sql_json/${search}`,
body: JSON.stringify(postPayload),
headers: { 'Content-Type': 'application/json' },
headers: {
'Content-Type': 'application/json',
...(query.detokenisation && { DETOKENISATION: 'True' }),
},
parseMethod: 'text',
})
.then(({ text = '{}' }) => {
Expand Down
26 changes: 26 additions & 0 deletions superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ class SqlEditor extends React.PureComponent {
LocalStorageKeys.sqllab__is_autocomplete_enabled,
true,
),
detokenisationEnabled: getItem(
LocalStorageKeys.sqllab__is_detokenisation_enabled,
false,
),
showCreateAsModal: false,
createAs: '',
};
Expand Down Expand Up @@ -396,6 +400,18 @@ class SqlEditor extends React.PureComponent {
});
};

handleToggleDetokenisationEnabled = () => {
  this.setState(({ detokenisationEnabled }) => {
    const enabled = !detokenisationEnabled;
    // Persist the new preference so it survives page reloads.
    setItem(LocalStorageKeys.sqllab__is_detokenisation_enabled, enabled);
    return { detokenisationEnabled: enabled };
  });
};

handleWindowResize() {
this.setState({ height: this.getSqlEditorHeight() });
}
Expand Down Expand Up @@ -458,6 +474,7 @@ class SqlEditor extends React.PureComponent {
ctas,
ctas_method,
updateTabState: !qe.selectedText,
detokenisation: this.state.detokenisationEnabled,
};
this.props.runQuery(query);
this.props.setActiveSouthPaneTab('Results');
Expand Down Expand Up @@ -510,6 +527,7 @@ class SqlEditor extends React.PureComponent {
<AceEditorWrapper
actions={this.props.actions}
autocomplete={this.state.autocompleteEnabled}
detokenisation={this.state.detokenisationEnabled}
onBlur={this.setQueryEditorSql}
onChange={this.onSqlChanged}
queryEditor={this.props.queryEditor}
Expand Down Expand Up @@ -696,6 +714,14 @@ class SqlEditor extends React.PureComponent {
</AntdDropdown>
</LimitSelectStyled>
</span>
<span>
<span>Detokenisation:</span>{' '}
<AntdSwitch
checked={this.state.detokenisationEnabled}
onChange={this.handleToggleDetokenisationEnabled}
name="detokenisation-switch"
/>
</span>
{this.props.latestQuery && (
<Timer
startTime={this.props.latestQuery.startDttm}
Expand Down
2 changes: 2 additions & 0 deletions superset-frontend/src/utils/localStorageHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export enum LocalStorageKeys {
* sqllab__is_autocomplete_enabled
*/
sqllab__is_autocomplete_enabled = 'sqllab__is_autocomplete_enabled',
sqllab__is_detokenisation_enabled = 'sqllab__is_detokenisation_enabled',
explore__data_table_time_formatted_columns = 'explore__data_table_time_formatted_columns',
}

Expand All @@ -63,6 +64,7 @@ export type LocalStorageValues = {
homepage_collapse_state: string[];
homepage_activity_filter: SetTabType | null;
sqllab__is_autocomplete_enabled: boolean;
sqllab__is_detokenisation_enabled: boolean;
explore__data_table_time_formatted_columns: Record<string, string[]>;
};

Expand Down
70 changes: 70 additions & 0 deletions superset/aric_detokeniser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
from __future__ import annotations

import asyncio
import contextvars
import functools
from asyncio import events

from requests_futures.sessions import FuturesSession
import json
import logging
from pandas import DataFrame

from superset import app

# Flask application config and module-level logger.
config = app.config
logger = logging.getLogger(__name__)

# One shared FuturesSession for all detokenisation calls so requests
# reuse a single connection pool and run in the session's worker threads.
# Every request authenticates with the service access token and posts a
# plain-text (JSON-encoded) body.
session = FuturesSession()
session.headers.update({
    'Access-Token': config['DETOKENISE_ACCESS_TOKEN'],
    'Content-Type': 'text/plain; charset=utf-8'
})


def detokenise_json(df: DataFrame) -> DataFrame:
    """Detokenise one object-dtype column via the remote detokenise service.

    NOTE(review): despite the ``DataFrame`` annotation, callers
    (``detokenise_post_process``) pass a single column, i.e. a pandas
    Series — confirm and tighten the annotation at the call sites.

    If the column holds at least one token (a string prefixed ``'t:'``),
    the whole column is POSTed as ``{"id": [...]}`` to
    ``DETOKENISE_POST_URL`` and the parsed JSON response is returned
    (shape depends on the service — presumably a detokenised list;
    TODO confirm). Columns without tokens are returned unchanged.
    """
    # Guard with isinstance: an object-dtype column may contain None or
    # numbers, and calling .startswith on those raised AttributeError.
    if df.dtype == 'object' and any(
        isinstance(val, str) and val.startswith('t:') for val in df
    ):
        data = json.dumps({"id": df.to_list()})
        # FuturesSession returns a future; block on .result() since the
        # caller already runs us inside a worker thread.
        req = session.post(config['DETOKENISE_POST_URL'], data=data)
        return req.result().json()
    return df


async def to_thread(func, /, *args, **kwargs):
    """Run *func* in a worker thread and await its result.

    Backport of :func:`asyncio.to_thread` (Python 3.9+): all positional
    and keyword arguments are forwarded to *func*, and the caller's
    current :class:`contextvars.Context` is propagated so context
    variables from the event-loop thread are visible in the worker.
    """
    loop = events.get_running_loop()
    bound = functools.partial(
        contextvars.copy_context().run, func, *args, **kwargs
    )
    return await loop.run_in_executor(None, bound)


async def detokenise_post_process(df: DataFrame) -> DataFrame:
    """Detokenise every column of *df* concurrently.

    Each column is handed to :func:`detokenise_json` in its own worker
    thread; the results are written back into *df* in place (the same
    frame is mutated and returned).
    """
    columns = list(df)
    results = await asyncio.gather(
        *(to_thread(detokenise_json, df[name]) for name in columns)
    )
    for name, detokenised in zip(columns, results):
        df[name] = detokenised
    return df
29 changes: 4 additions & 25 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
# pylint: disable=invalid-name
from __future__ import annotations

from requests_futures.sessions import FuturesSession
from multiprocessing import Pool
import re
import asyncio

import json
import logging
from datetime import datetime, timedelta
Expand All @@ -30,6 +29,7 @@
from pandas import DataFrame

from superset import app
from superset.aric_detokeniser import detokenise_post_process
from superset.common.chart_data import ChartDataResultType
from superset.exceptions import (
InvalidPostProcessingError,
Expand Down Expand Up @@ -57,12 +57,6 @@
config = app.config
logger = logging.getLogger(__name__)

session = FuturesSession()
session.headers.update({
'Access-Token': config['DETOKENISE_ACCESS_TOKEN'],
'Content-Type': 'text/plain; charset=utf-8'
})

# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
# https://github.com/python/mypy/issues/5288

Expand Down Expand Up @@ -412,21 +406,6 @@ def cache_key(self, **extra: Any) -> str:

return md5_sha_from_dict(cache_dict, default=json_int_dttm_ser, ignore_nan=True)

@staticmethod
def detokenise(token: str) -> str:
if re.search(r't:(.*)', token):
req = session.post(config['DETOKENISE_POST_URL'], data='\"'+token+'\"')
return str(req.result().text)
return token

@classmethod
def detokeniser(cls, df: DataFrame) -> DataFrame:
if df.dtype == 'object':
p = Pool()
df = p.map(cls.detokenise, df)
p.close()
return df

def exec_post_processing(self, df: DataFrame) -> DataFrame:
"""
Perform post processing operations on DataFrame.
Expand All @@ -440,7 +419,7 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame:
logger.debug("post_processing: \n %s", pformat(self.post_processing))

if self.detoken_select:
df = df.apply(self.detokeniser)
df = asyncio.run(detokenise_post_process(df))

for post_process in self.post_processing:
operation = post_process.get("operation")
Expand Down
15 changes: 12 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
import dataclasses
import logging
import uuid
Expand All @@ -32,6 +33,7 @@
from sqlalchemy.orm import Session

from superset import app, results_backend, results_backend_use_msgpack, security_manager
from superset.aric_detokeniser import detokenise_post_process
from superset.common.db_query_status import QueryStatus
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
Expand Down Expand Up @@ -315,6 +317,7 @@ def _serialize_and_expand_data(
db_engine_spec: BaseEngineSpec,
use_msgpack: Optional[bool] = False,
expand_data: bool = False,
detokenisation: Optional[bool] = False,
) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
selected_columns = result_set.columns
all_columns: List[Any]
Expand All @@ -324,6 +327,10 @@ def _serialize_and_expand_data(
with stats_timing(
"sqllab.query.results_backend_pa_serialization", stats_logger
):
if detokenisation:
result_set.table = pa.Table.from_pandas(
asyncio.run(detokenise_post_process(result_set.to_pandas_df()))
)
data = (
pa.default_serialization_context()
.serialize(result_set.pa_table)
Expand All @@ -335,6 +342,8 @@ def _serialize_and_expand_data(
all_columns, expanded_columns = (selected_columns, [])
else:
df = result_set.to_pandas_df()
if detokenisation:
df = asyncio.run(detokenise_post_process(df))
data = df_to_records(df) or []

if expand_data:
Expand Down Expand Up @@ -505,10 +514,10 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
latest_partition=False,
)
query.end_time = now_as_float()

detokenisation = log_params.get('detokenisation')
use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
result_set, db_engine_spec, use_arrow_data, expand_data
result_set, db_engine_spec, use_arrow_data, expand_data, detokenisation
)

# TODO: data should be saved separately from metadata (likely in Parquet)
Expand Down Expand Up @@ -560,7 +569,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
all_columns,
expanded_columns,
) = _serialize_and_expand_data(
result_set, db_engine_spec, False, expand_data
result_set, db_engine_spec, False, expand_data, detokenisation
)
payload.update(
{
Expand Down
5 changes: 4 additions & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2567,7 +2567,10 @@ def validate_sql_json(
def sql_json(self) -> FlaskResponse:
try:
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT")),
"detokenisation": cast(bool, bool(request.headers.get(
"DETOKENISATION"
)))
}
execution_context = SqlJsonExecutionContext(request.json)
command = self._create_sql_json_command(execution_context, log_params)
Expand Down

0 comments on commit 464dcef

Please sign in to comment.