fix: Time shifts calculation for ECharts plugins (#28432)
michael-s-molina authored May 15, 2024
1 parent f0b7b95 commit 821c7d7
Showing 8 changed files with 347 additions and 163 deletions.
@@ -21,5 +21,5 @@ export { getMetricOffsetsMap } from './getMetricOffsetsMap';
 export { isTimeComparison } from './isTimeComparison';
 export { isDerivedSeries } from './isDerivedSeries';
 export { extractExtraMetrics } from './extractExtraMetrics';
-export { getOriginalSeries, hasTimeOffset } from './timeOffset';
+export { getOriginalSeries, hasTimeOffset, getTimeOffset } from './timeOffset';
 export { TIME_COMPARISON_SEPARATOR } from './constants';
@@ -20,19 +20,23 @@
 import { JsonObject } from '@superset-ui/core';
 import { isString } from 'lodash';
 
+export const getTimeOffset = (
+  series: JsonObject,
+  timeCompare: string[],
+): string | undefined =>
+  timeCompare.find(
+    timeOffset =>
+      // offset is represented as <offset>, group by list
+      series.name.includes(`${timeOffset},`) ||
+      // offset is represented as <metric>__<offset>
+      series.name.includes(`__${timeOffset}`),
+  );
+
 export const hasTimeOffset = (
   series: JsonObject,
   timeCompare: string[],
 ): boolean =>
-  isString(series.name)
-    ? !!timeCompare.find(
-        timeOffset =>
-          // offset is represented as <offset>, group by list
-          series.name.includes(`${timeOffset},`) ||
-          // offset is represented as <metric>__<offset>
-          series.name.includes(`__${timeOffset}`),
-      )
-    : false;
+  isString(series.name) ? !!getTimeOffset(series, timeCompare) : false;
 
 export const getOriginalSeries = (
   seriesName: string,
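Note on the naming conventions the comments above refer to: a derived (time-shifted) series name embeds its offset either as `<offset>, <group-by values>` or as `<metric>__<offset>`. A standalone Python sketch of the same matching rule, with made-up series names (not Superset code):

```python
# Standalone sketch of getTimeOffset's matching rule (hypothetical names).
def get_time_offset(series_name: str, time_compare: list[str]) -> str | None:
    for offset in time_compare:
        # "<offset>, <group-by values>"  or  "<metric>__<offset>"
        if f"{offset}," in series_name or f"__{offset}" in series_name:
            return offset
    return None


assert get_time_offset("1 year ago, USA", ["1 year ago"]) == "1 year ago"
assert get_time_offset("sum__1 year ago", ["1 year ago"]) == "1 year ago"
assert get_time_offset("sum", ["1 year ago"]) is None
```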
@@ -45,9 +45,10 @@ import {
   extractExtraMetrics,
   getOriginalSeries,
   isDerivedSeries,
+  getTimeOffset,
 } from '@superset-ui/chart-controls';
 import { EChartsCoreOption, SeriesOption } from 'echarts';
-import { ZRLineType } from 'echarts/types/src/util/types';
+import { LineStyleOption } from 'echarts/types/src/util/types';
 import {
   EchartsTimeseriesChartProps,
   EchartsTimeseriesFormData,
@@ -273,10 +274,22 @@ export default function transformProps(
   const array = ensureIsArray(chartProps.rawFormData?.time_compare);
   const inverted = invert(verboseMap);
 
+  const offsetLineWidths = {};
+
   rawSeries.forEach(entry => {
-    const lineStyle = isDerivedSeries(entry, chartProps.rawFormData)
-      ? { type: 'dashed' as ZRLineType }
-      : {};
+    const derivedSeries = isDerivedSeries(entry, chartProps.rawFormData);
+    const lineStyle: LineStyleOption = {};
+    if (derivedSeries) {
+      const offset = getTimeOffset(
+        entry,
+        ensureIsArray(chartProps.rawFormData?.time_compare),
+      )!;
+      if (!offsetLineWidths[offset]) {
+        offsetLineWidths[offset] = Object.keys(offsetLineWidths).length + 1;
+      }
+      lineStyle.type = 'dashed';
+      lineStyle.width = offsetLineWidths[offset];
+    }
 
     const entryName = String(entry.name || '');
     const seriesName = inverted[entryName] || entryName;
@@ -288,6 +301,7 @@
       colorScaleKey,
       {
         area,
+        connectNulls: derivedSeries,
         filterState,
         seriesContexts,
         markerEnabled,
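The block above assigns each distinct time offset its own dashed-line width, so several shifted series stay visually distinguishable. A minimal Python sketch of that bookkeeping (the offset labels are hypothetical):

```python
# Each previously unseen offset gets the next integer width: 1, 2, 3, ...
offset_line_widths: dict[str, int] = {}


def width_for(offset: str) -> int:
    if offset not in offset_line_widths:
        offset_line_widths[offset] = len(offset_line_widths) + 1
    return offset_line_widths[offset]


assert width_for("1 week ago") == 1
assert width_for("1 year ago") == 2
assert width_for("1 week ago") == 1  # stable on later lookups
```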
@@ -143,6 +143,7 @@ export function transformSeries(
   colorScaleKey: string,
   opts: {
     area?: boolean;
+    connectNulls?: boolean;
     filterState?: FilterState;
     seriesContexts?: { [key: string]: ForecastSeriesEnum[] };
     markerEnabled?: boolean;
@@ -170,6 +171,7 @@
   const { name } = series;
   const {
     area,
+    connectNulls,
     filterState,
     seriesContexts = {},
     markerEnabled,
@@ -268,6 +270,7 @@
       : { ...opts.lineStyle, opacity };
   return {
     ...series,
+    connectNulls,
     queryIndex,
     yAxisIndex,
     name: forecastSeries.name,
186 changes: 119 additions & 67 deletions superset/common/query_context_processor.py
@@ -77,8 +77,8 @@
 stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)
 
-# Temporary column used for joining aggregated offset results
-AGGREGATED_JOIN_COLUMN = "__aggregated_join_column"
+# Offset join column suffix used for joining offset results
+OFFSET_JOIN_COLUMN_SUFFIX = "__offset_join_column_"
 
 # This only includes time grains that may influence
 # the temporal column used for joining offset results.
@@ -339,19 +339,31 @@ def get_time_grain(query_object: QueryObject) -> Any | None:
 
         return query_object.extras.get("time_grain_sqla")
 
-    def add_aggregated_join_column(
+    # pylint: disable=too-many-arguments
+    def add_offset_join_column(
         self,
         df: pd.DataFrame,
+        name: str,
         time_grain: str,
+        time_offset: str | None = None,
         join_column_producer: Any = None,
     ) -> None:
+        """
+        Adds an offset join column to the provided DataFrame.
+
+        The function modifies the DataFrame in-place.
+
+        :param df: pandas DataFrame to which the offset join column will be added.
+        :param name: The name of the new column to be added.
+        :param time_grain: The time grain used to calculate the new column.
+        :param time_offset: The time offset used to calculate the new column.
+        :param join_column_producer: A function to generate the join column.
+        """
         if join_column_producer:
-            df[AGGREGATED_JOIN_COLUMN] = df.apply(
-                lambda row: join_column_producer(row, 0), axis=1
-            )
+            df[name] = df.apply(lambda row: join_column_producer(row, 0), axis=1)
         else:
-            df[AGGREGATED_JOIN_COLUMN] = df.apply(
-                lambda row: self.get_aggregated_join_column(row, 0, time_grain),
+            df[name] = df.apply(
+                lambda row: self.generate_join_column(row, 0, time_grain, time_offset),
                 axis=1,
             )
 
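For context, a standalone pandas sketch (not Superset code) of the in-place pattern `add_offset_join_column` uses: derive a string key from the first, temporal column of each row and store it under the given column name.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "__timestamp": pd.to_datetime(["2024-01-01", "2024-02-01"]),
        "metric": [10, 20],
    }
)

# Hypothetical column name; in the PR it is OFFSET_JOIN_COLUMN_SUFFIX + offset.
name = "__offset_join_column_1 year ago"
# Render a monthly join key from the temporal (first) column of each row.
df[name] = df.apply(lambda row: row.iloc[0].strftime("%Y-%m"), axis=1)

print(df[name].tolist())  # ['2024-01', '2024-02']
```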
@@ -365,7 +377,7 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-statements
         query_object_clone = copy.copy(query_object)
         queries: list[str] = []
         cache_keys: list[str | None] = []
-        offset_dfs: list[pd.DataFrame] = []
+        offset_dfs: dict[str, pd.DataFrame] = {}
 
         outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
         if not outer_from_dttm or not outer_to_dttm:
@@ -376,28 +388,17 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-statements
                 )
             )
 
-        columns = df.columns
         time_grain = self.get_time_grain(query_object)
 
         if not time_grain:
             raise QueryObjectValidationError(
                 _("Time Grain must be specified when using Time Shift.")
             )
 
-        join_column_producer = config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
-            time_grain
-        )
-        use_aggregated_join_column = (
-            join_column_producer or time_grain in AGGREGATED_JOIN_GRAINS
-        )
-        if use_aggregated_join_column:
-            self.add_aggregated_join_column(df, time_grain, join_column_producer)
-            # skips the first column which is the temporal column
-            # because we'll use the aggregated join columns instead
-            columns = df.columns[1:]
-
         metric_names = get_metric_names(query_object.metrics)
-        join_keys = [col for col in columns if col not in metric_names]
+
+        # use columns that are not metrics as join keys
+        join_keys = [col for col in df.columns if col not in metric_names]
 
         for offset in query_object.time_offsets:
             try:
@@ -443,7 +444,7 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-statements
             )
             # whether hit on the cache
             if cache.is_loaded:
-                offset_dfs.append(cache.df)
+                offset_dfs[offset] = cache.df
                 queries.append(cache.query)
                 cache_keys.append(cache_key)
                 continue
@@ -497,16 +498,6 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-statements
                     )
                 )
 
-            # modifies temporal column using offset
-            offset_metrics_df[index] = offset_metrics_df[index] - DateOffset(
-                **normalize_time_delta(offset)
-            )
-
-            if use_aggregated_join_column:
-                self.add_aggregated_join_column(
-                    offset_metrics_df, time_grain, join_column_producer
-                )
-
             # cache df and query
             value = {
                 "df": offset_metrics_df,
@@ -519,51 +510,112 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-statements
                     datasource_uid=query_context.datasource.uid,
                     region=CacheRegion.DATA,
                 )
-            offset_dfs.append(offset_metrics_df)
+            offset_dfs[offset] = offset_metrics_df
 
         if offset_dfs:
-            # iterate on offset_dfs, left join each with df
-            for offset_df in offset_dfs:
-                df = dataframe_utils.left_join_df(
-                    left_df=df,
-                    right_df=offset_df,
-                    join_keys=join_keys,
-                    rsuffix=R_SUFFIX,
-                )
-
-            # removes columns used for join
-            df.drop(
-                list(df.filter(regex=f"{AGGREGATED_JOIN_COLUMN}|{R_SUFFIX}")),
-                axis=1,
-                inplace=True,
+            df = self.join_offset_dfs(
+                df,
+                offset_dfs,
+                time_grain,
+                join_keys,
             )
 
         return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
 
-    @staticmethod
-    def get_aggregated_join_column(
-        row: pd.Series, column_index: int, time_grain: str
+    def join_offset_dfs(
+        self,
+        df: pd.DataFrame,
+        offset_dfs: dict[str, pd.DataFrame],
+        time_grain: str,
+        join_keys: list[str],
+    ) -> pd.DataFrame:
+        """
+        Join offset DataFrames with the main DataFrame.
+
+        :param df: The main DataFrame.
+        :param offset_dfs: A list of offset DataFrames.
+        :param time_grain: The time grain used to calculate the temporal join key.
+        :param join_keys: The keys to join on.
+        """
+        join_column_producer = config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
+            time_grain
+        )
+
+        # iterate on offset_dfs, left join each with df
+        for offset, offset_df in offset_dfs.items():
+            # defines a column name for the offset join column
+            column_name = OFFSET_JOIN_COLUMN_SUFFIX + offset
+
+            # add offset join column to df
+            self.add_offset_join_column(
+                df, column_name, time_grain, offset, join_column_producer
+            )
+
+            # add offset join column to offset_df
+            self.add_offset_join_column(
+                offset_df, column_name, time_grain, None, join_column_producer
+            )
+
+            # the temporal column is the first column in the join keys
+            # so we use the join column instead of the temporal column
+            actual_join_keys = [column_name, *join_keys[1:]]
+
+            # left join df with offset_df
+            df = dataframe_utils.left_join_df(
+                left_df=df,
+                right_df=offset_df,
+                join_keys=actual_join_keys,
+                rsuffix=R_SUFFIX,
+            )
+
+            # move the temporal column to the first column in df
+            col = df.pop(join_keys[0])
+            df.insert(0, col.name, col)
+
+        # removes columns created only for join purposes
+        df.drop(
+            list(df.filter(regex=f"{OFFSET_JOIN_COLUMN_SUFFIX}|{R_SUFFIX}")),
+            axis=1,
+            inplace=True,
+        )
+        return df
+
+    @staticmethod
+    def generate_join_column(
+        row: pd.Series,
+        column_index: int,
+        time_grain: str,
+        time_offset: str | None = None,
     ) -> str:
-        if time_grain in (
-            TimeGrain.WEEK_STARTING_SUNDAY,
-            TimeGrain.WEEK_ENDING_SATURDAY,
-        ):
-            return row[column_index].strftime("%Y-W%U")
-
-        if time_grain in (
-            TimeGrain.WEEK,
-            TimeGrain.WEEK_STARTING_MONDAY,
-            TimeGrain.WEEK_ENDING_SUNDAY,
-        ):
-            return row[column_index].strftime("%Y-W%W")
-
-        if time_grain == TimeGrain.MONTH:
-            return row[column_index].strftime("%Y-%m")
-
-        if time_grain == TimeGrain.QUARTER:
-            return row[column_index].strftime("%Y-Q") + str(row[column_index].quarter)
-
-        return row[column_index].strftime("%Y")
+        value = row[column_index]
+
+        if hasattr(value, "strftime"):
+            if time_offset:
+                value = value + DateOffset(**normalize_time_delta(time_offset))
+
+            if time_grain in (
+                TimeGrain.WEEK_STARTING_SUNDAY,
+                TimeGrain.WEEK_ENDING_SATURDAY,
+            ):
+                return value.strftime("%Y-W%U")
+
+            if time_grain in (
+                TimeGrain.WEEK,
+                TimeGrain.WEEK_STARTING_MONDAY,
+                TimeGrain.WEEK_ENDING_SUNDAY,
+            ):
+                return value.strftime("%Y-W%W")
+
+            if time_grain == TimeGrain.MONTH:
+                return value.strftime("%Y-%m")
+
+            if time_grain == TimeGrain.QUARTER:
+                return value.strftime("%Y-Q") + str(value.quarter)
+
+            if time_grain == TimeGrain.YEAR:
+                return value.strftime("%Y")
+
+        return str(value)
 
     def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
         if self._query_context.result_format in ChartDataResultFormat.table_like():
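To make the join keys concrete, a standalone sketch (plain pandas, no Superset imports) of the strings `generate_join_column` produces per time grain; the date is made up:

```python
import pandas as pd

ts = pd.Timestamp("2024-05-15")

print(ts.strftime("%Y-W%U"))                  # weekly key, weeks starting Sunday: "2024-W19"
print(ts.strftime("%Y-W%W"))                  # weekly key, weeks starting Monday: "2024-W20"
print(ts.strftime("%Y-%m"))                   # monthly key: "2024-05"
print(ts.strftime("%Y-Q") + str(ts.quarter))  # quarterly key: "2024-Q2"
print(ts.strftime("%Y"))                      # yearly key: "2024"
```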