From 821c7d7f2c430c4a4294883a66128ba98fd949c5 Mon Sep 17 00:00:00 2001
From: "Michael S. Molina" <70410625+michael-s-molina@users.noreply.github.com>
Date: Wed, 15 May 2024 08:11:52 -0300
Subject: [PATCH] fix: Time shifts calculation for ECharts plugins (#28432)

---
 .../src/operators/utils/index.ts              |   2 +-
 .../src/operators/utils/timeOffset.ts         |  22 ++-
 .../src/Timeseries/transformProps.ts          |  22 ++-
 .../src/Timeseries/transformers.ts            |   3 +
 superset/common/query_context_processor.py    | 186 ++++++++++-------
 .../integration_tests/query_context_tests.py  |  11 +-
 .../common/test_get_aggregated_join_column.py |  77 --------
 tests/unit_tests/common/test_time_shifts.py   | 187 ++++++++++++++++++
 8 files changed, 347 insertions(+), 163 deletions(-)
 delete mode 100644 tests/unit_tests/common/test_get_aggregated_join_column.py
 create mode 100644 tests/unit_tests/common/test_time_shifts.py

diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/index.ts b/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/index.ts
index 1d91a6965f52b..f461db0c5a637 100644
--- a/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/index.ts
+++ b/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/index.ts
@@ -21,5 +21,5 @@ export { getMetricOffsetsMap } from './getMetricOffsetsMap';
 export { isTimeComparison } from './isTimeComparison';
 export { isDerivedSeries } from './isDerivedSeries';
 export { extractExtraMetrics } from './extractExtraMetrics';
-export { getOriginalSeries, hasTimeOffset } from './timeOffset';
+export { getOriginalSeries, hasTimeOffset, getTimeOffset } from './timeOffset';
 export { TIME_COMPARISON_SEPARATOR } from './constants';
diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/timeOffset.ts b/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/timeOffset.ts
index b11572c6dda69..8a7d9a964f8b8 100644
--- a/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/timeOffset.ts
+++ b/superset-frontend/packages/superset-ui-chart-controls/src/operators/utils/timeOffset.ts
@@ -20,19 +20,23 @@
 import { JsonObject } from '@superset-ui/core';
 import { isString } from 'lodash';
 
+export const getTimeOffset = (
+  series: JsonObject,
+  timeCompare: string[],
+): string | undefined =>
+  timeCompare.find(
+    timeOffset =>
+      // offset is represented as <offset>, group by list
+      series.name.includes(`${timeOffset},`) ||
+      // offset is represented as <metric>__<offset>
+      series.name.includes(`__${timeOffset}`),
+  );
+
 export const hasTimeOffset = (
   series: JsonObject,
   timeCompare: string[],
 ): boolean =>
-  isString(series.name)
-    ? !!timeCompare.find(
-        timeOffset =>
-          // offset is represented as <offset>, group by list
-          series.name.includes(`${timeOffset},`) ||
-          // offset is represented as <metric>__<offset>
-          series.name.includes(`__${timeOffset}`),
-      )
-    : false;
+  isString(series.name) ? !!getTimeOffset(series, timeCompare) : false;
 
 export const getOriginalSeries = (
   seriesName: string,
diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
index d3e7673ea738d..e63d67f8b610f 100644
--- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
+++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts
@@ -45,9 +45,10 @@ import {
   extractExtraMetrics,
   getOriginalSeries,
   isDerivedSeries,
+  getTimeOffset,
 } from '@superset-ui/chart-controls';
 import { EChartsCoreOption, SeriesOption } from 'echarts';
-import { ZRLineType } from 'echarts/types/src/util/types';
+import { LineStyleOption } from 'echarts/types/src/util/types';
 import {
   EchartsTimeseriesChartProps,
   EchartsTimeseriesFormData,
@@ -273,10 +274,22 @@ export default function transformProps(
   const array = ensureIsArray(chartProps.rawFormData?.time_compare);
   const inverted = invert(verboseMap);
 
+  const offsetLineWidths = {};
+
   rawSeries.forEach(entry => {
-    const lineStyle = isDerivedSeries(entry, chartProps.rawFormData)
-      ? { type: 'dashed' as ZRLineType }
-      : {};
+    const derivedSeries = isDerivedSeries(entry, chartProps.rawFormData);
+    const lineStyle: LineStyleOption = {};
+    if (derivedSeries) {
+      const offset = getTimeOffset(
+        entry,
+        ensureIsArray(chartProps.rawFormData?.time_compare),
+      )!;
+      if (!offsetLineWidths[offset]) {
+        offsetLineWidths[offset] = Object.keys(offsetLineWidths).length + 1;
+      }
+      lineStyle.type = 'dashed';
+      lineStyle.width = offsetLineWidths[offset];
+    }
 
     const entryName = String(entry.name || '');
     const seriesName = inverted[entryName] || entryName;
@@ -288,6 +301,7 @@
       colorScaleKey,
       {
         area,
+        connectNulls: derivedSeries,
         filterState,
         seriesContexts,
         markerEnabled,
diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformers.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformers.ts
index 3b5cfd594a70b..baa109002cb8b 100644
--- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformers.ts
+++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformers.ts
@@ -143,6 +143,7 @@ export function transformSeries(
   colorScaleKey: string,
   opts: {
     area?: boolean;
+    connectNulls?: boolean;
     filterState?: FilterState;
     seriesContexts?: { [key: string]: ForecastSeriesEnum[] };
     markerEnabled?: boolean;
@@ -170,6 +171,7 @@ export function transformSeries(
   const { name } = series;
   const {
     area,
+    connectNulls,
    filterState,
     seriesContexts = {},
     markerEnabled,
@@ -268,6 +270,7 @@ export function transformSeries(
     : { ...opts.lineStyle, opacity };
   return {
     ...series,
+    connectNulls,
     queryIndex,
     yAxisIndex,
     name: forecastSeries.name,
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 55c80386a316e..c47e295e96c52 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -77,8 +77,8 @@
 stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)
 
-# Temporary column used for joining aggregated offset results
-AGGREGATED_JOIN_COLUMN = "__aggregated_join_column"
+# Offset join column suffix used for joining offset results
+OFFSET_JOIN_COLUMN_SUFFIX = "__offset_join_column_"
 
 # This only includes time grains that may influence
 # the temporal column used for joining offset results.
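
The query_context_processor.py hunks that follow replace the single AGGREGATED_JOIN_COLUMN with one temporary join column per time offset, named OFFSET_JOIN_COLUMN_SUFFIX plus the offset. A minimal sketch of the join-key idea, separate from the patch and assuming only pandas (the column name, offset label, and frames here are illustrative):

    import pandas as pd
    from pandas.tseries.offsets import DateOffset

    current = pd.DataFrame(
        {"ds": pd.to_datetime(["2024-01-31", "2024-02-29"]), "metric": [10, 12]}
    )
    previous = pd.DataFrame(
        {"ds": pd.to_datetime(["2023-01-31", "2023-02-28"]), "metric": [7, 9]}
    )

    # One join column per offset: shift the current frame back by the offset,
    # then bucket both sides with the same grain-aware format ("%Y-%m" for a
    # MONTH-like grain), so rows align even when the exact days differ
    # (2024-02-29 minus a year is 2023-02-28, but both land in "2023-02").
    join_col = "__offset_join_column_1 year ago"
    current[join_col] = (current["ds"] - DateOffset(years=1)).dt.strftime("%Y-%m")
    previous[join_col] = previous["ds"].dt.strftime("%Y-%m")

    joined = current.merge(previous, on=join_col, suffixes=("", "__1 year ago"))

Because each offset gets its own key column, the offset frames can be joined one by one without clobbering each other's join columns.
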
@@ -339,19 +339,31 @@ def get_time_grain(query_object: QueryObject) -> Any | None:
         return query_object.extras.get("time_grain_sqla")
 
-    def add_aggregated_join_column(
+    # pylint: disable=too-many-arguments
+    def add_offset_join_column(
         self,
         df: pd.DataFrame,
+        name: str,
         time_grain: str,
+        time_offset: str | None = None,
         join_column_producer: Any = None,
     ) -> None:
+        """
+        Adds an offset join column to the provided DataFrame.
+
+        The function modifies the DataFrame in-place.
+
+        :param df: pandas DataFrame to which the offset join column will be added.
+        :param name: The name of the new column to be added.
+        :param time_grain: The time grain used to calculate the new column.
+        :param time_offset: The time offset used to calculate the new column.
+        :param join_column_producer: A function to generate the join column.
+        """
         if join_column_producer:
-            df[AGGREGATED_JOIN_COLUMN] = df.apply(
-                lambda row: join_column_producer(row, 0), axis=1
-            )
+            df[name] = df.apply(lambda row: join_column_producer(row, 0), axis=1)
         else:
-            df[AGGREGATED_JOIN_COLUMN] = df.apply(
-                lambda row: self.get_aggregated_join_column(row, 0, time_grain),
+            df[name] = df.apply(
+                lambda row: self.generate_join_column(row, 0, time_grain, time_offset),
                 axis=1,
             )
@@ -365,7 +377,7 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-stateme
         query_object_clone = copy.copy(query_object)
         queries: list[str] = []
         cache_keys: list[str | None] = []
-        offset_dfs: list[pd.DataFrame] = []
+        offset_dfs: dict[str, pd.DataFrame] = {}
 
         outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
         if not outer_from_dttm or not outer_to_dttm:
@@ -376,7 +388,6 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-stateme
                 )
             )
 
-        columns = df.columns
         time_grain = self.get_time_grain(query_object)
 
         if not time_grain:
@@ -384,20 +395,10 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-stateme
                 _("Time Grain must be specified when using Time Shift.")
             )
 
-        join_column_producer = config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
-            time_grain
-        )
-        use_aggregated_join_column = (
-            join_column_producer or time_grain in AGGREGATED_JOIN_GRAINS
-        )
-        if use_aggregated_join_column:
-            self.add_aggregated_join_column(df, time_grain, join_column_producer)
-            # skips the first column which is the temporal column
-            # because we'll use the aggregated join columns instead
-            columns = df.columns[1:]
-
         metric_names = get_metric_names(query_object.metrics)
-        join_keys = [col for col in columns if col not in metric_names]
+
+        # use columns that are not metrics as join keys
+        join_keys = [col for col in df.columns if col not in metric_names]
 
         for offset in query_object.time_offsets:
             try:
@@ -443,7 +444,7 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-stateme
                 )
                 # whether hit on the cache
                 if cache.is_loaded:
-                    offset_dfs.append(cache.df)
+                    offset_dfs[offset] = cache.df
                     queries.append(cache.query)
                     cache_keys.append(cache_key)
                     continue
@@ -497,16 +498,6 @@ def processing_time_offsets(  # pylint: disable=too-many-locals,too-many-stateme
                     )
                 )
 
-            # modifies temporal column using offset
-            offset_metrics_df[index] = offset_metrics_df[index] - DateOffset(
-                **normalize_time_delta(offset)
-            )
-
-            if use_aggregated_join_column:
-                self.add_aggregated_join_column(
-                    offset_metrics_df, time_grain, join_column_producer
-                )
-
             # cache df and query
             value = {
                 "df": offset_metrics_df,
@@ -519,51 +510,112 @@
                 datasource_uid=query_context.datasource.uid,
                 region=CacheRegion.DATA,
             )
-            offset_dfs.append(offset_metrics_df)
+            offset_dfs[offset] = offset_metrics_df
 
         if offset_dfs:
-            # iterate on offset_dfs, left join each with df
-            for offset_df in offset_dfs:
-                df = dataframe_utils.left_join_df(
-                    left_df=df,
-                    right_df=offset_df,
-                    join_keys=join_keys,
-                    rsuffix=R_SUFFIX,
-                )
+            df = self.join_offset_dfs(
+                df,
+                offset_dfs,
+                time_grain,
+                join_keys,
+            )
+
+        return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
+
+    def join_offset_dfs(
+        self,
+        df: pd.DataFrame,
+        offset_dfs: dict[str, pd.DataFrame],
+        time_grain: str,
+        join_keys: list[str],
+    ) -> pd.DataFrame:
+        """
+        Join offset DataFrames with the main DataFrame.
 
-            # removes columns used for join
-            df.drop(
-                list(df.filter(regex=f"{AGGREGATED_JOIN_COLUMN}|{R_SUFFIX}")),
-                axis=1,
-                inplace=True,
+        :param df: The main DataFrame.
+        :param offset_dfs: A dict mapping each time offset to its offset DataFrame.
+        :param time_grain: The time grain used to calculate the temporal join key.
+        :param join_keys: The keys to join on.
+        """
+        join_column_producer = config["TIME_GRAIN_JOIN_COLUMN_PRODUCERS"].get(
+            time_grain
         )
 
-        return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
+        # iterate on offset_dfs, left join each with df
+        for offset, offset_df in offset_dfs.items():
+            # defines a column name for the offset join column
+            column_name = OFFSET_JOIN_COLUMN_SUFFIX + offset
+
+            # add offset join column to df
+            self.add_offset_join_column(
+                df, column_name, time_grain, offset, join_column_producer
+            )
+
+            # add offset join column to offset_df
+            self.add_offset_join_column(
+                offset_df, column_name, time_grain, None, join_column_producer
+            )
+
+            # the temporal column is the first column in the join keys
+            # so we use the join column instead of the temporal column
+            actual_join_keys = [column_name, *join_keys[1:]]
+
+            # left join df with offset_df
+            df = dataframe_utils.left_join_df(
+                left_df=df,
+                right_df=offset_df,
+                join_keys=actual_join_keys,
+                rsuffix=R_SUFFIX,
+            )
+
+            # move the temporal column to the first column in df
+            col = df.pop(join_keys[0])
+            df.insert(0, col.name, col)
+
+        # removes columns created only for join purposes
+        df.drop(
+            list(df.filter(regex=f"{OFFSET_JOIN_COLUMN_SUFFIX}|{R_SUFFIX}")),
+            axis=1,
+            inplace=True,
+        )
+        return df
 
     @staticmethod
-    def get_aggregated_join_column(
-        row: pd.Series, column_index: int, time_grain: str
+    def generate_join_column(
+        row: pd.Series,
+        column_index: int,
+        time_grain: str,
+        time_offset: str | None = None,
     ) -> str:
-        if time_grain in (
-            TimeGrain.WEEK_STARTING_SUNDAY,
-            TimeGrain.WEEK_ENDING_SATURDAY,
-        ):
-            return row[column_index].strftime("%Y-W%U")
+        value = row[column_index]
 
-        if time_grain in (
-            TimeGrain.WEEK,
-            TimeGrain.WEEK_STARTING_MONDAY,
-            TimeGrain.WEEK_ENDING_SUNDAY,
-        ):
-            return row[column_index].strftime("%Y-W%W")
+        if hasattr(value, "strftime"):
+            if time_offset:
+                value = value + DateOffset(**normalize_time_delta(time_offset))
+
+            if time_grain in (
+                TimeGrain.WEEK_STARTING_SUNDAY,
+                TimeGrain.WEEK_ENDING_SATURDAY,
+            ):
+                return value.strftime("%Y-W%U")
+
+            if time_grain in (
+                TimeGrain.WEEK,
+                TimeGrain.WEEK_STARTING_MONDAY,
+                TimeGrain.WEEK_ENDING_SUNDAY,
+            ):
+                return value.strftime("%Y-W%W")
+
+            if time_grain == TimeGrain.MONTH:
+                return value.strftime("%Y-%m")
 
-        if time_grain == TimeGrain.MONTH:
-            return row[column_index].strftime("%Y-%m")
+            if time_grain == TimeGrain.QUARTER:
+                return value.strftime("%Y-Q") + str(value.quarter)
 
-        if time_grain == TimeGrain.QUARTER:
-            return row[column_index].strftime("%Y-Q") + str(row[column_index].quarter)
+            if time_grain == TimeGrain.YEAR:
+                return value.strftime("%Y")
 
-        return row[column_index].strftime("%Y")
+        return str(value)
 
     def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
         if self._query_context.result_format in ChartDataResultFormat.table_like():
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 3a14c37a43ef1..9c18b5e07c9d8 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -554,9 +554,9 @@ def test_processing_time_offsets_cache(self):
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         # query without cache
-        query_context.processing_time_offsets(df, query_object)
+        query_context.processing_time_offsets(df.copy(), query_object)
         # query with cache
-        rv = query_context.processing_time_offsets(df, query_object)
+        rv = query_context.processing_time_offsets(df.copy(), query_object)
         cache_keys = rv["cache_keys"]
         cache_keys__1_year_ago = cache_keys[0]
         cache_keys__1_year_later = cache_keys[1]
@@ -568,7 +568,7 @@ def test_processing_time_offsets_cache(self):
         payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"]
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
-        rv = query_context.processing_time_offsets(df, query_object)
+        rv = query_context.processing_time_offsets(df.copy(), query_object)
         cache_keys = rv["cache_keys"]
         self.assertEqual(cache_keys__1_year_ago, cache_keys[1])
         self.assertEqual(cache_keys__1_year_later, cache_keys[0])
@@ -578,10 +578,11 @@ def test_processing_time_offsets_cache(self):
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         rv = query_context.processing_time_offsets(
-            df,
+            df.copy(),
             query_object,
         )
-        self.assertIs(rv["df"], df)
+
+        self.assertEqual(rv["df"].shape, df.shape)
         self.assertEqual(rv["queries"], [])
         self.assertEqual(rv["cache_keys"], [])
 
diff --git a/tests/unit_tests/common/test_get_aggregated_join_column.py b/tests/unit_tests/common/test_get_aggregated_join_column.py
deleted file mode 100644
index de0b6b92b2850..0000000000000
--- a/tests/unit_tests/common/test_get_aggregated_join_column.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from pandas import DataFrame, Series, Timestamp
-from pandas.testing import assert_frame_equal
-from pytest import fixture, mark
-
-from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
-from superset.common.query_context import QueryContext
-from superset.common.query_context_processor import (
-    AGGREGATED_JOIN_COLUMN,
-    QueryContextProcessor,
-)
-from superset.connectors.sqla.models import BaseDatasource
-from superset.constants import TimeGrain
-
-query_context_processor = QueryContextProcessor(
-    QueryContext(
-        datasource=BaseDatasource(),
-        queries=[],
-        result_type=ChartDataResultType.COLUMNS,
-        form_data={},
-        slice_=None,
-        result_format=ChartDataResultFormat.CSV,
-        cache_values={},
-    )
-)
-
-
-@fixture
-def make_join_column_producer():
-    def join_column_producer(row: Series, column_index: int) -> str:
-        return "CUSTOM_FORMAT"
-
-    return join_column_producer
-
-
-@mark.parametrize(
-    ("time_grain", "expected"),
-    [
-        (TimeGrain.WEEK, "2020-W01"),
-        (TimeGrain.MONTH, "2020-01"),
-        (TimeGrain.QUARTER, "2020-Q1"),
-        (TimeGrain.YEAR, "2020"),
-    ],
-)
-def test_aggregated_join_column(time_grain: str, expected: str):
-    df = DataFrame({"ds": [Timestamp("2020-01-07")]})
-    query_context_processor.add_aggregated_join_column(df, time_grain)
-    result = DataFrame(
-        {"ds": [Timestamp("2020-01-07")], AGGREGATED_JOIN_COLUMN: [expected]}
-    )
-    assert_frame_equal(df, result)
-
-
-def test_aggregated_join_column_producer(make_join_column_producer):
-    df = DataFrame({"ds": [Timestamp("2020-01-07")]})
-    query_context_processor.add_aggregated_join_column(
-        df, TimeGrain.YEAR, make_join_column_producer
-    )
-    result = DataFrame(
-        {"ds": [Timestamp("2020-01-07")], AGGREGATED_JOIN_COLUMN: ["CUSTOM_FORMAT"]}
-    )
-    assert_frame_equal(df, result)
diff --git a/tests/unit_tests/common/test_time_shifts.py b/tests/unit_tests/common/test_time_shifts.py
new file mode 100644
index 0000000000000..3f25236a768bc
--- /dev/null
+++ b/tests/unit_tests/common/test_time_shifts.py
@@ -0,0 +1,187 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pandas import DataFrame, Series, Timestamp
+from pandas.testing import assert_frame_equal
+from pytest import fixture, mark
+
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
+from superset.common.query_context import QueryContext
+from superset.common.query_context_processor import QueryContextProcessor
+from superset.connectors.sqla.models import BaseDatasource
+from superset.constants import TimeGrain
+
+query_context_processor = QueryContextProcessor(
+    QueryContext(
+        datasource=BaseDatasource(),
+        queries=[],
+        result_type=ChartDataResultType.COLUMNS,
+        form_data={},
+        slice_=None,
+        result_format=ChartDataResultFormat.CSV,
+        cache_values={},
+    )
+)
+
+
+@fixture
+def make_join_column_producer():
+    def join_column_producer(row: Series, column_index: int) -> str:
+        return "CUSTOM_FORMAT"
+
+    return join_column_producer
+
+
+@mark.parametrize(
+    ("time_grain", "expected"),
+    [
+        (TimeGrain.WEEK, "2020-W01"),
+        (TimeGrain.MONTH, "2020-01"),
+        (TimeGrain.QUARTER, "2020-Q1"),
+        (TimeGrain.YEAR, "2020"),
+    ],
+)
+def test_join_column(time_grain: str, expected: str):
+    df = DataFrame({"ds": [Timestamp("2020-01-07")]})
+    column_name = "join_column"
+    query_context_processor.add_offset_join_column(df, column_name, time_grain)
+    result = DataFrame({"ds": [Timestamp("2020-01-07")], column_name: [expected]})
+    assert_frame_equal(df, result)
+
+
+def test_join_column_producer(make_join_column_producer):
+    df = DataFrame({"ds": [Timestamp("2020-01-07")]})
+    column_name = "join_column"
+    query_context_processor.add_offset_join_column(
+        df, column_name, TimeGrain.YEAR, None, make_join_column_producer
+    )
+    result = DataFrame(
+        {"ds": [Timestamp("2020-01-07")], column_name: ["CUSTOM_FORMAT"]}
+    )
+    assert_frame_equal(df, result)
+
+
+def test_join_offset_dfs_no_offsets():
+    df = DataFrame({"A": ["2021-01-01", "2021-02-01", "2021-03-01"]})
+    offset_dfs = {}
+    time_grain = "YEAR"
+    join_keys = ["A"]
+
+    result = query_context_processor.join_offset_dfs(
+        df, offset_dfs, time_grain, join_keys
+    )
+
+    assert_frame_equal(df, result)
+
+
+def test_join_offset_dfs_with_offsets():
+    df = DataFrame({"A": ["2021-01-01", "2021-02-01", "2021-03-01"]})
+    offset_df = DataFrame(
+        {"A": ["2021-02-01", "2021-03-01", "2021-04-01"], "B": [5, 6, 7]}
+    )
+    offset_dfs = {"1_YEAR": offset_df}
+    time_grain = "YEAR"
+    join_keys = ["A"]
+
+    expected = DataFrame(
+        {"A": ["2021-01-01", "2021-02-01", "2021-03-01"], "B": [None, 5, 6]}
+    )
+
+    result = query_context_processor.join_offset_dfs(
+        df, offset_dfs, time_grain, join_keys
+    )
+
+    assert_frame_equal(expected, result)
+
+
+def test_join_offset_dfs_with_multiple_offsets():
+    df = DataFrame({"A": ["2021-01-01", "2021-02-01", "2021-03-01"]})
+    offset_df1 = DataFrame(
+        {"A": ["2021-02-01", "2021-03-01", "2021-04-01"], "B": [5, 6, 7]}
+    )
+    offset_df2 = DataFrame(
+        {"A": ["2021-03-01", "2021-04-01", "2021-05-01"], "C": [8, 9, 10]}
+    )
+    offset_dfs = {"1_YEAR": offset_df1, "2_YEAR": offset_df2}
+    time_grain = "YEAR"
+    join_keys = ["A"]
+
+    expected = DataFrame(
+        {
+            "A": ["2021-01-01", "2021-02-01", "2021-03-01"],
+            "B": [None, 5, 6],
+            "C": [None, None, 8],
+        }
+    )
+
+    result = query_context_processor.join_offset_dfs(
+        df, offset_dfs, time_grain, join_keys
+    )
+
+    assert_frame_equal(expected, result)
+
+
+def test_join_offset_dfs_with_month_granularity():
+    df = DataFrame(
+        {
+            "A": [
+                "2021-01-01",
+                "2021-01-15",
+                "2021-02-01",
+                "2021-02-15",
+                "2021-03-01",
+                "2021-03-15",
+            ],
+            "D": [1, 2, 3, 4, 5, 6],
+        }
+    )
+    offset_df = DataFrame(
+        {
+            "A": [
+                "2021-02-01",
+                "2021-02-15",
+                "2021-03-01",
+                "2021-03-15",
+                "2021-04-01",
+                "2021-04-15",
+            ],
+            "B": [5, 6, 7, 8, 9, 10],
+        }
+    )
+    offset_dfs = {"1_MONTH": offset_df}
+    time_grain = "MONTH"
+    join_keys = ["A"]
+
+    expected = DataFrame(
+        {
+            "A": [
+                "2021-01-01",
+                "2021-01-15",
+                "2021-02-01",
+                "2021-02-15",
+                "2021-03-01",
+                "2021-03-15",
+            ],
+            "D": [1, 2, 3, 4, 5, 6],
+            "B": [None, None, 5, 6, 7, 8],
+        }
+    )
+
+    result = query_context_processor.join_offset_dfs(
+        df, offset_dfs, time_grain, join_keys
+    )
+
+    assert_frame_equal(expected, result)
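
The tests above drive join_offset_dfs through the string fallback in generate_join_column; with real Timestamp columns the new code applies the offset before formatting the grain-aware key. A hedged sketch of that behaviour (the helper name is hypothetical, not part of the patch; only pandas is assumed):

    import pandas as pd
    from pandas.tseries.offsets import DateOffset

    # Mirrors the Timestamp branch of generate_join_column for a MONTH grain:
    # shift the current row back by the offset, then bucket by "%Y-%m".
    def month_join_key(ts: pd.Timestamp, months_back: int = 0) -> str:
        return (ts - DateOffset(months=months_back)).strftime("%Y-%m")

    current = pd.Timestamp("2021-03-31")
    previous = pd.Timestamp("2021-02-28")

    # The raw dates differ (31st vs 28th), so joining on exact timestamps
    # would drop the row; the bucketed keys still line up.
    assert month_join_key(current, months_back=1) == month_join_key(previous)

This day-of-month drift is why the old approach of subtracting a DateOffset from the temporal column and joining on it directly could miss rows, and why the per-offset, grain-bucketed join columns are used instead.
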