diff --git a/dev_scripts/checks.py b/dev_scripts/checks.py deleted file mode 100644 index 265a864..0000000 --- a/dev_scripts/checks.py +++ /dev/null @@ -1,42 +0,0 @@ -import pandas as pd -import portfolyo as pf -from portfolyo.core.shared import concat - - -def get_idx( - startdate: str, starttime: str, tz: str, freq: str, enddate: str -) -> pd.DatetimeIndex: - # Empty index. - if startdate is None: - return pd.DatetimeIndex([], freq=freq, tz=tz) - # Normal index. - ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) - ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) - return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") - - -index = pd.date_range("2020", "2024", freq="QS", inclusive="left") -# index2 = pd.date_range("2023", "2025", freq="QS", inclusive="left") -# pfl = pf.dev.get_flatpfline(index) -# pfl2 = pf.dev.get_flatpfline(index2) -# print(pfl) -# print(pfl2) - -# pfs = pf.dev.get_pfstate(index) - -# pfs2 = pf.dev.get_pfstate(index2) -# pfl3 = concat.general(pfl, pfl2) -# print(pfl3) - -# print(index) -# print(index2) - -whole_pfl = pf.dev.get_nestedpfline(index) -pfl_a = whole_pfl.slice[:"2021"] - -pfl_b = whole_pfl.slice["2021":"2022"] -pfl_c = whole_pfl.slice["2022":] -result = concat.concat_pflines(pfl_a, pfl_b, pfl_c) -result2 = concat.concat_pflines(pfl_b, pfl_c, pfl_a) -print(result) -print(result2) diff --git a/docs/core/pfline.rst b/docs/core/pfline.rst index 2c907cc..e4ccb54 100644 --- a/docs/core/pfline.rst +++ b/docs/core/pfline.rst @@ -270,6 +270,7 @@ Another slicing method is implemented with the ``.slice[]`` property. The improv + Concatenation ============= diff --git a/docs/core/toplevel.rst b/docs/core/toplevel.rst index 4b2c891..7bd08e7 100644 --- a/docs/core/toplevel.rst +++ b/docs/core/toplevel.rst @@ -23,8 +23,11 @@ Work on pandas objects Work on portfolyo objects ------------------------- -* ``portfolyo.concat()`` Concatenates PfLines into one PfLine. +* ``portfolyo.concat()`` Concatenates PfLines (or PfStates) into one PfLine (or PfState). * ``portfolyo.plot_pfstates()`` Plots several PfStates in one figure. +* ``portfolyo.intersection()`` Intersect several dataframes and/or series and/or Pflines and/or PfStates. + + diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt deleted file mode 100644 index c6f63b7..0000000 --- a/docs/requirements-docs.txt +++ /dev/null @@ -1,14 +0,0 @@ -sphinx -sphinx-autobuild -# These are dependencies of various sphinx extensions for documentation. -ipython==8.12 # readthedocs needs python <=3.8, and this only works with ipython <=8.12 -matplotlib -numpydoc -sphinx-copybutton -sphinx-exec-code -sphinx_rtd_theme - -insegel - -nbsphinx -pandoc \ No newline at end of file diff --git a/docs/savefig/fig_hedge.png b/docs/savefig/fig_hedge.png index 2fea7c3..13bb208 100644 Binary files a/docs/savefig/fig_hedge.png and b/docs/savefig/fig_hedge.png differ diff --git a/docs/savefig/fig_offtake.png b/docs/savefig/fig_offtake.png index db42f87..160a93e 100644 Binary files a/docs/savefig/fig_offtake.png and b/docs/savefig/fig_offtake.png differ diff --git a/portfolyo/__init__.py b/portfolyo/__init__.py index 8eef784..1e48760 100644 --- a/portfolyo/__init__.py +++ b/portfolyo/__init__.py @@ -3,10 +3,11 @@ from . import _version, dev, testing, tools from .core import extendpandas # extend functionalty of pandas from .core import suppresswarnings +from .tools2.plot import plot_pfstates from .core.pfline import Kind, PfLine, Structure, create from .core.pfstate import PfState -from .core.shared.concat import general as concat -from .core.shared.plot import plot_pfstates +from .tools2.concat import general as concat +from .tools2.plot import plot_pfstates from .prices.hedge import hedge from .prices.utils import is_peak_hour from .tools.changefreq import averagable as asfreq_avg @@ -17,6 +18,8 @@ from .tools.tzone import force_agnostic, force_aware from .tools.unit import Q_, Unit, ureg from .tools.wavg import general as wavg +from .tools2.concat import general as concat +from .tools2.intersect import indexable as intersection # from .core.shared.concat import general as concat diff --git a/portfolyo/core/shared/plot.py b/portfolyo/core/shared/plot.py index 460cb67..a328d25 100644 --- a/portfolyo/core/shared/plot.py +++ b/portfolyo/core/shared/plot.py @@ -4,10 +4,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import matplotlib -import numpy as np + from matplotlib import pyplot as plt from ... import tools @@ -115,51 +115,6 @@ def plot(self: PfLine, cols: str = None) -> plt.Figure: class PfStatePlot: - # def plot_to_ax( - # self: PfState, ax: plt.Axes, line: str = "offtake", col: str = None, **kwargs - # ) -> None: - # """Plot a timeseries of a PfState in the portfolio state to a specific axes. - - # Parameters - # ---------- - # ax : plt.Axes - # The axes object to which to plot the timeseries. - # line : str, optional - # The pfline to plot. One of {'offtake' (default), 'sourced', 'unsourced', - # 'netposition', 'procurement', 'sourcedfraction'}. - # col : str, optional - # The column to plot. Default: plot volume `w` [MW] (if available) or else - # price `p` [Eur/MWh]. - # Any additional kwargs are passed to the pd.Series.plot function. - # """ - # if line == "offtake": - # how = DEFAULTHOW.get(col, "step") - # (-self.offtake).plot_to_ax(ax, col, how) - # ax.bar_label( - # ax.containers[0], label_type="edge", fmt="%,.0f".replace(",", " ") - # ) - - # elif line.endswith("sourcedfraction"): # (un)sourcedfraction - # fractions = getattr(self, line) - # vis.plot_timeseries(ax, fractions, how="bar", color="grey") - # ax.bar_label( - # ax.containers[0], - # label_type="edge", - # labels=fractions.apply("{:.0%}".format), - # ) # print labels on top of each bar - - # elif line == "sourced": - # self.sourced.plot_to_ax( - # ax, - # col, - # ) - # if col == "p": - - # vis.plot_timeseries(ax, self.unsourcedprice["p"], how="bar", alpha=0.0) - # ax.bar_label( - # ax.containers[0], label_type="center", fmt="%.2f" - # ) # print labels on top of each bar - def plot(self: PfState) -> plt.Figure: """Plot the portfolio state. @@ -225,137 +180,3 @@ def plot(self: PfState) -> plt.Figure: fig.tight_layout() return fig - - -def plot_pfstates(dic: Dict[str, PfState], freq: str = "MS") -> plt.Figure: - """Plot multiple PfState instances. - - Parameters - ---------- - dic : Dict[str, PfState] - Dictionary with PfState instances as values, and their names as the keys. - - Returns - ------- - plt.Figure - The figure object to which the instances were plotted. - """ - - gridspec = {"width_ratios": [0.3, 1, 1], "height_ratios": [4, 1] * len(dic)} - figsize = (15, 5 * len(dic)) - fig, axes = plt.subplots(len(dic) * 2, 3, gridspec_kw=gridspec, figsize=figsize) - axesgroups = axes.flatten().reshape((len(dic), 6)) - - # Share x axes. - sharex = axesgroups[:, (1, 2, 4)].flatten() - for ax1, ax2 in zip(sharex[1:], sharex[:-1]): - ax1.sharex(ax2) - # Share y axes. - sharey = axesgroups[:, 2] - for ax1, ax2 in zip(sharey[1:], sharey[:-1]): - ax1.sharey(ax2) - - # TODO: resample all to have same index (frequency and length). - - for i, ((pfname, pfs), axes) in enumerate(zip(dic.items(), axesgroups)): - # If freq is MS or longer: use categorical axes. Plot volumes in MWh. - # If freq is D or shorter: use time axes. Plot volumes in MW. - is_category = tools.freq.shortest(pfs.index.freq, "MS") == "MS" - - # Portfolio name. - axes[0].text( - 0, - 1, - pfname.replace(" ", "\n"), - fontsize=12, - fontweight="bold", - verticalalignment="top", - horizontalalignment="left", - ) - axes[0].axis("off") - - # Volumes. - if is_category: - s, kwargs = -1 * pfs.offtakevolume.q, defaultkwargs("q", is_category) - else: - s, kwargs = -1 * pfs.offtakevolume.w, defaultkwargs("w", is_category) - vis.plot_timeseries(axes[1], s, **kwargs) - - # Sourced fraction. - vis.plot_timeseries( - axes[2], pfs.sourcedfraction, **defaultkwargs("f", is_category) - ) - - # Empty. - axes[3].axis("off") - - # Procurement Price. - vis.plot_timeseries(axes[4], pfs.pnl_cost.p, **defaultkwargs("p", is_category)) - - # Empty. - axes[5].axis("off") - - # Tick formatting. - axes[2].yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(1.0)) - axes[1].yaxis.set_major_formatter( - matplotlib.ticker.FuncFormatter( - lambda x, p: "{:,.0f}".format(x).replace(",", " ") - ) - ) - - for a, ax in enumerate(axes): - if i == 0 and a in [1, 2]: - ax.xaxis.set_tick_params(labelbottom=False, labeltop=True, pad=25) - else: - ax.xaxis.set_tick_params(labelbottom=False, labeltop=False) - - if i == 0: - axes[1].set_title("Offtake Volume &\nprocurement price", y=1.27) - axes[2].set_title("Sourced fraction", y=1.27) - - return - draw_horizontal_lines(fig, axes) # draw horizontal lines between portfolios - - -def draw_horizontal_lines(fig, axes): - """Function to draw horizontal lines between multiple portfolios. - This function does not return anything, but tries to plot a 2D line after every 2 axes, eg. - after (0,2), (0,4),... beacuse each portfolio requires 2x4 axes in the fig (where rows=2, columns=4). - - Parameters - ---------- - fig : plt.subplots() - axes : plt.subplots() - """ - # rearange the axes for no overlap - fig.tight_layout() - - # Get the bounding boxes of the axes including text decorations - r = fig.canvas.get_renderer() - bboxes = np.array( - [ - ax.get_tightbbox(r).transformed(fig.transFigure.inverted()) - for ax in axes.flat - ], - matplotlib.transforms.Bbox, - ).reshape(axes.shape) - - """TO CORRECT: the horizontal line is not exactly in the middle of two graphs. - It is more inclined towards the second or next graph in the queue. - Each pftstate has 4x4 grid and this is plotted in the same graph, but as subgraphs. - """ - - # Get the minimum and maximum extent, get the coordinate half-way between those - ymax = ( - np.array(list(map(lambda b: b.y1, bboxes.flat))).reshape(axes.shape).max(axis=1) - ) - ymin = ( - np.array(list(map(lambda b: b.y0, bboxes.flat))).reshape(axes.shape).min(axis=1) - ) - ys = np.c_[ymax[2:-1:2], ymin[1:-2:2]].mean(axis=1) - ys = [ymax[0], *ys] - - # Draw a horizontal lines at those coordinates - for y in ys: - line = plt.Line2D([0, 1], [y, y], transform=fig.transFigure, color="black") - fig.add_artist(line) diff --git a/portfolyo/tools/intersect.py b/portfolyo/tools/intersect.py index fa940a2..d28cdac 100644 --- a/portfolyo/tools/intersect.py +++ b/portfolyo/tools/intersect.py @@ -1,6 +1,10 @@ -from typing import List, Union - +from typing import List, Union, Tuple import pandas as pd +from portfolyo import tools + +from portfolyo.tools.right import stamp +from portfolyo.tools.freq import longest, longer_or_shorter +from datetime import datetime def indices(*idxs: pd.DatetimeIndex) -> pd.DatetimeIndex: @@ -57,8 +61,138 @@ def indices(*idxs: pd.DatetimeIndex) -> pd.DatetimeIndex: return pd.DatetimeIndex(sorted(list(values)), freq=freq, name=name, tz=tz) +def indices_flex( + *idxs: pd.DatetimeIndex, + ignore_freq: bool = False, + ignore_tz: bool = False, + ignore_start_of_day: bool = False, +) -> Tuple[pd.DatetimeIndex]: + """Intersect several DatetimeIndices, but allow for more flexibility of ignoring + certain properties. + + Parameters + ---------- + *idxs : pd.DatetimeIndex + The indices to intersect. + ignore_freq: bool, optional (default: False) + If True, do the intersection even if the frequencies do not match; drop the + time periods that do not (fully) exist in either of the frames. + ignore_tz: bool, optional (default: False) + If True, ignore the timezones; perform the intersection using 'wall time'. + ignore_start_of_day: bool, optional (default: False) + If True, perform the intersection even if the frames have a different start-of-day. + The start-of-day of the original frames is preserved, even if the frequency is shorter + than daily. + + Returns + ------- + Tuple[pd.DatetimeIndex] + The intersection for each datetimeindex (in same order as input idxs). + + See also + -------- + indices + """ + if len(idxs) == 0: + raise ValueError("Must specify at least one index.") + + if len(idxs) == 1: + return idxs[0] + # convert tuple object into a list + idxs = list(idxs) + + # If we land here, we have at least 2 indices. + distinct_freqs = set([i.freq for i in idxs]) + if len(distinct_freqs) != 1 and ignore_freq is False: + raise ValueError(f"Indices must have equal frequencies; got {distinct_freqs}.") + + distinct_tzs = set([i.tz for i in idxs]) + if len(distinct_tzs) != 1 and ignore_tz is False: + raise ValueError(f"Indices must have equal timezones; got {distinct_tzs}.") + + empty_idx = [len(i) == 0 for i in idxs] + if any(empty_idx): + return pd.DatetimeIndex([]) + + # If we land here, we have at least 2 indices, all are not empty. + + distinct_sod = set([i[0].time() for i in idxs]) + if len(distinct_sod) != 1 and ignore_start_of_day is False: + raise ValueError(f"Indices must have equal start-of-day; got {distinct_sod}.") + for i in range(len(idxs)): + if len(distinct_sod) != 1 and longer_or_shorter(idxs[i].freq, "D") == -1: + raise ValueError( + "Downsample all indices to daily-or-longer, or trim them so they have the same start-of-day, before attempting to calculate the intersection" + ) + + freq, name, tz = [], [], [] + for i in range(len(idxs)): + freq.append(idxs[i].freq) + name.append(idxs[i].name) + tz.append(idxs[i].tz) + + longest_freq = freq[0] + if ignore_freq is True and len(distinct_freqs) != 1: + # Find the longest frequency + longest_freq = longest(*freq) + # trim datetimeindex + for i in range(len(idxs)): + # if idxs[i].freq is not the same as longest freq, we trim idxs[i] + if idxs[i].freq != longest_freq: + idxs[i] = tools.trim.index(idxs[i], longest_freq) + + if ignore_tz is True and len(distinct_tzs) != 1: + # set timezone to none for all values + for i in range(len(idxs)): + idxs[i] = idxs[i].tz_localize(None) + + if ignore_start_of_day is True and len(distinct_sod) != 1: + # Save a copy of the original hours and minutes + start_of_day = [x[0].time() for x in idxs] + # Set the time components to midnight for each index in the list + idxs = [index.normalize() for index in idxs] + + # Calculation is cumbersome: pandas DatetimeIndex.intersection not working correctly on timezone-aware indices (#46702) + values = set(idxs[0]) + # intersection is not working on datetimeindex with different freq->we need to use mask + for i in idxs[1:]: + values = values.intersection(set(i)) + values = sorted(values) + + if len(values) == 0: + return tuple([pd.DatetimeIndex([]) for _i in idxs]) + + idxs_out = [] + for i in range(len(idxs)): + start = min(values) + # end = stamp(start, longest_freq._prefix) + end = max(values) + end = stamp(end, longest_freq) + + if ignore_start_of_day is True: + start = datetime.combine(pd.to_datetime(start).date(), start_of_day[i]) + end = datetime.combine(pd.to_datetime(end).date(), start_of_day[i]) + # inclusive = "left" + + idxs_out.append( + pd.date_range( + start=start, + end=end, + freq=freq[i], + name=name[i], + tz=tz[i], + inclusive="left", + ) + ) + + return tuple(idxs_out) + + def frames( - *frames: Union[pd.Series, pd.DataFrame] + *frames: Union[pd.Series, pd.DataFrame], + ignore_freq: bool = False, + ignore_tz: bool = False, + ignore_start_of_day: bool = False, ) -> List[Union[pd.Series, pd.DataFrame]]: """Intersect several dataframes and/or series. @@ -66,6 +200,15 @@ def frames( ---------- *frames : pd.Series and/or pd.DataFrame The frames to intersect. + ignore_freq: bool, optional (default: False) + If True, do the intersection even if the frequencies do not match; drop the + time periods that do not (fully) exist in either of the frames. + ignore_tz: bool, optional (default: False) + If True, ignore the timezones; perform the intersection using 'wall time'. + ignore_start_of_day: bool, optional (default: False) + If True, perform the intersection even if the frames have a different start-of-day. + The start-of-day of the original frames is preserved, even if the frequency is shorter + than daily. Returns ------- @@ -77,5 +220,10 @@ def frames( The indices must have equal frequency, timezone, start-of-day. Otherwise, an error is raised. If there is no overlap, empty frames are returned. """ - common_index = indices(*[fr.index for fr in frames]) - return [fr.loc[common_index] for fr in frames] + new_idxs = indices_flex( + *[fr.index for fr in frames], + ignore_freq=ignore_freq, + ignore_tz=ignore_tz, + ignore_start_of_day=ignore_start_of_day, + ) + return [fr.loc[idx] for idx, fr in zip(new_idxs, frames)] diff --git a/portfolyo/core/shared/concat.py b/portfolyo/tools2/concat.py similarity index 97% rename from portfolyo/core/shared/concat.py rename to portfolyo/tools2/concat.py index 5f105d2..bcb5dce 100644 --- a/portfolyo/core/shared/concat.py +++ b/portfolyo/tools2/concat.py @@ -5,11 +5,11 @@ import pandas as pd from portfolyo import tools -from ..pfstate import PfState -from ..pfline.enums import Structure +from ..core.pfstate import PfState +from ..core.pfline.enums import Structure -from ..pfline import PfLine, create -from .. import pfstate +from ..core.pfline import PfLine, create +from ..core import pfstate def general(pfl_or_pfs: Iterable[PfLine | PfState]) -> None: diff --git a/portfolyo/tools2/intersect.py b/portfolyo/tools2/intersect.py new file mode 100644 index 0000000..8b0664c --- /dev/null +++ b/portfolyo/tools2/intersect.py @@ -0,0 +1,47 @@ +from portfolyo.tools.intersect import indices_flex +from ..core.pfline import PfLine +from ..core.pfstate import PfState +from typing import List, Union + +import pandas as pd + + +def indexable( + *frames: Union[pd.Series, pd.DataFrame, PfLine, PfState], + ignore_freq: bool = False, + ignore_tz: bool = False, + ignore_start_of_day: bool = False, +) -> List[Union[pd.Series, pd.DataFrame, PfLine, PfState]]: + """Intersect several dataframes and/or series. + + Parameters + ---------- + *frames : pd.Series and/or pd.DataFrame and/or PfLines and/or PfStates + The frames to intersect. + ignore_freq: bool, optional (default: False) + If True, do the intersection even if the frequencies do not match; drop the + time periods that do not (fully) exist in either of the frames. + ignore_tz: bool, optional (default: False) + If True, ignore the timezones; perform the intersection using 'wall time'. + ignore_start_of_day: bool, optional (default: False) + If True, perform the intersection even if the frames have a different start-of-day. + The start-of-day of the original frames is preserved, even if the frequency is shorter + than daily. + + Returns + ------- + list of series and/or dataframes + As input, but trimmed to their intersection. + + Notes + ----- + The indices must have equal frequency, timezone, start-of-day. Otherwise, an error + is raised. If there is no overlap, empty frames are returned. + """ + new_idxs = indices_flex( + *[fr.index for fr in frames], + ignore_freq=ignore_freq, + ignore_tz=ignore_tz, + ignore_start_of_day=ignore_start_of_day, + ) + return [fr.loc[idx] for idx, fr in zip(new_idxs, frames)] diff --git a/portfolyo/tools2/plot.py b/portfolyo/tools2/plot.py new file mode 100644 index 0000000..59fcaa6 --- /dev/null +++ b/portfolyo/tools2/plot.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import Dict + +import matplotlib +import numpy as np +from matplotlib import pyplot as plt + +from portfolyo.core.shared.plot import defaultkwargs + +from .. import tools +from .. import visualize as vis + +from ..core.pfstate import PfState + + +def plot_pfstates(dic: Dict[str, PfState], freq: str = "MS") -> plt.Figure: + """Plot multiple PfState instances. + + Parameters + ---------- + dic : Dict[str, PfState] + Dictionary with PfState instances as values, and their names as the keys. + + Returns + ------- + plt.Figure + The figure object to which the instances were plotted. + """ + + gridspec = {"width_ratios": [0.3, 1, 1], "height_ratios": [4, 1] * len(dic)} + figsize = (15, 5 * len(dic)) + fig, axes = plt.subplots(len(dic) * 2, 3, gridspec_kw=gridspec, figsize=figsize) + axesgroups = axes.flatten().reshape((len(dic), 6)) + + # Share x axes. + sharex = axesgroups[:, (1, 2, 4)].flatten() + for ax1, ax2 in zip(sharex[1:], sharex[:-1]): + ax1.sharex(ax2) + # Share y axes. + sharey = axesgroups[:, 2] + for ax1, ax2 in zip(sharey[1:], sharey[:-1]): + ax1.sharey(ax2) + + # TODO: resample all to have same index (frequency and length). + + for i, ((pfname, pfs), axes) in enumerate(zip(dic.items(), axesgroups)): + # If freq is MS or longer: use categorical axes. Plot volumes in MWh. + # If freq is D or shorter: use time axes. Plot volumes in MW. + is_category = tools.freq.shortest(pfs.index.freq, "MS") == "MS" + + # Portfolio name. + axes[0].text( + 0, + 1, + pfname.replace(" ", "\n"), + fontsize=12, + fontweight="bold", + verticalalignment="top", + horizontalalignment="left", + ) + axes[0].axis("off") + + # Volumes. + if is_category: + s, kwargs = -1 * pfs.offtakevolume.q, defaultkwargs("q", is_category) + else: + s, kwargs = -1 * pfs.offtakevolume.w, defaultkwargs("w", is_category) + vis.plot_timeseries(axes[1], s, **kwargs) + + # Sourced fraction. + vis.plot_timeseries( + axes[2], pfs.sourcedfraction, **defaultkwargs("f", is_category) + ) + + # Empty. + axes[3].axis("off") + + # Procurement Price. + vis.plot_timeseries(axes[4], pfs.pnl_cost.p, **defaultkwargs("p", is_category)) + + # Empty. + axes[5].axis("off") + + # Tick formatting. + axes[2].yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(1.0)) + axes[1].yaxis.set_major_formatter( + matplotlib.ticker.FuncFormatter( + lambda x, p: "{:,.0f}".format(x).replace(",", " ") + ) + ) + + for a, ax in enumerate(axes): + if i == 0 and a in [1, 2]: + ax.xaxis.set_tick_params(labelbottom=False, labeltop=True, pad=25) + else: + ax.xaxis.set_tick_params(labelbottom=False, labeltop=False) + + if i == 0: + axes[1].set_title("Offtake Volume &\nprocurement price", y=1.27) + axes[2].set_title("Sourced fraction", y=1.27) + + return + draw_horizontal_lines(fig, axes) # draw horizontal lines between portfolios + + +def draw_horizontal_lines(fig, axes): + """Function to draw horizontal lines between multiple portfolios. + This function does not return anything, but tries to plot a 2D line after every 2 axes, eg. + after (0,2), (0,4),... beacuse each portfolio requires 2x4 axes in the fig (where rows=2, columns=4). + + Parameters + ---------- + fig : plt.subplots() + axes : plt.subplots() + """ + # rearange the axes for no overlap + fig.tight_layout() + + # Get the bounding boxes of the axes including text decorations + r = fig.canvas.get_renderer() + bboxes = np.array( + [ + ax.get_tightbbox(r).transformed(fig.transFigure.inverted()) + for ax in axes.flat + ], + matplotlib.transforms.Bbox, + ).reshape(axes.shape) + + """TO CORRECT: the horizontal line is not exactly in the middle of two graphs. + It is more inclined towards the second or next graph in the queue. + Each pftstate has 4x4 grid and this is plotted in the same graph, but as subgraphs. + """ + + # Get the minimum and maximum extent, get the coordinate half-way between those + ymax = ( + np.array(list(map(lambda b: b.y1, bboxes.flat))).reshape(axes.shape).max(axis=1) + ) + ymin = ( + np.array(list(map(lambda b: b.y0, bboxes.flat))).reshape(axes.shape).min(axis=1) + ) + ys = np.c_[ymax[2:-1:2], ymin[1:-2:2]].mean(axis=1) + ys = [ymax[0], *ys] + + # Draw a horizontal lines at those coordinates + for y in ys: + line = plt.Line2D([0, 1], [y, y], transform=fig.transFigure, color="black") + fig.add_artist(line) diff --git a/setup.cfg b/setup.cfg index 4e57090..4edd93f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ max-line-length = 120 ignore = E501, W503, E202, E226 [tool:pytest] -addopts = --cov=. +#addopts = --cov=. markers = only_on_pr: marks tests as slow (select with -m only_on_pr and deselect with -m "not only_on_pr") pythonpath = ./tests diff --git a/tests/core/shared/test_concat_error_cases.py b/tests/core/shared/test_concat_error_cases.py index a5eb551..4c8e8ff 100644 --- a/tests/core/shared/test_concat_error_cases.py +++ b/tests/core/shared/test_concat_error_cases.py @@ -7,7 +7,7 @@ from portfolyo import dev from portfolyo.core.pfline.enums import Kind from portfolyo.core.pfstate.pfstate import PfState -from portfolyo.core.shared import concat +from portfolyo.tools2 import concat def test_general(): diff --git a/tests/core/shared/test_concat_pfline.py b/tests/core/shared/test_concat_pfline.py index f066602..0a862dc 100644 --- a/tests/core/shared/test_concat_pfline.py +++ b/tests/core/shared/test_concat_pfline.py @@ -3,7 +3,7 @@ import pandas as pd import pytest from portfolyo import dev -from portfolyo.core.shared import concat +from portfolyo.tools2 import concat TESTCASES2 = [ # whole idx, freq, where diff --git a/tests/core/shared/test_concat_pfstate.py b/tests/core/shared/test_concat_pfstate.py index 3ce923e..8f2c2de 100644 --- a/tests/core/shared/test_concat_pfstate.py +++ b/tests/core/shared/test_concat_pfstate.py @@ -3,7 +3,7 @@ import pandas as pd import pytest from portfolyo import dev -from portfolyo.core.shared import concat +from portfolyo.tools2 import concat TESTCASES2 = [ # whole idx, freq, where diff --git a/tests/tools/test_intersect.py b/tests/tools/test_intersect.py index fbb83f4..fed0837 100644 --- a/tests/tools/test_intersect.py +++ b/tests/tools/test_intersect.py @@ -154,7 +154,7 @@ def test_intersect_nooverlap(indexorframe: str, tz: str, freq: str, starttime: s get_idx("2020-01-01", starttime, tz, freq, "2022-01-01"), get_idx("2023-01-01", starttime, tz, freq, "2025-01-01"), ] - do_test_intersect(indexorframe, idxs, None, "", tz, freq) + do_test_intersect(indexorframe, idxs, None, "", tz, freq, check_freq=False) def do_test_intersect( @@ -164,12 +164,21 @@ def do_test_intersect( expected_starttime: str = None, expected_tz: str = None, expected_freq: str = None, + **kwargs, ): if indexorframe == "idx": - do_test_fn = do_test_intersect_index + do_test_intersect_index( + idxs, expected_startdate, expected_starttime, expected_tz, expected_freq + ) else: - do_test_fn = do_test_intersect_frame - do_test_fn(idxs, expected_startdate, expected_starttime, expected_tz, expected_freq) + do_test_intersect_frame( + idxs, + expected_startdate, + expected_starttime, + expected_tz, + expected_freq, + **kwargs, + ) def do_test_intersect_index( @@ -200,17 +209,31 @@ def do_test_intersect_frame( expected_starttime: str = None, expected_tz: str = None, expected_freq: str = None, + ignore_freq: bool = False, + ignore_start_of_day: bool = False, + ignore_tz: bool = False, + **kwargs, ): frames = get_frames(idxs) # Error case. if type(expected_startdate) is type and issubclass(expected_startdate, Exception): with pytest.raises(expected_startdate): - tools.intersect.frames(*frames) + tools.intersect.frames( + *frames, + ignore_start_of_day=ignore_start_of_day, + ignore_tz=ignore_tz, + ignore_freq=ignore_freq, + ) return # Normal case. - result_frames = tools.intersect.frames(*frames) + result_frames = tools.intersect.frames( + *frames, + ignore_freq=ignore_freq, + ignore_start_of_day=ignore_start_of_day, + ignore_tz=ignore_tz, + ) expected_index = get_idx( expected_startdate, expected_starttime, expected_tz, expected_freq ) @@ -218,6 +241,6 @@ def do_test_intersect_frame( for result, expected in zip(result_frames, expected_frames): if isinstance(result, pd.Series): - testing.assert_series_equal(result, expected) + testing.assert_series_equal(result, expected, **kwargs) else: - testing.assert_frame_equal(result, expected) + testing.assert_frame_equal(result, expected, **kwargs) diff --git a/tests/tools/test_intersect_flex.py b/tests/tools/test_intersect_flex.py new file mode 100644 index 0000000..29e1281 --- /dev/null +++ b/tests/tools/test_intersect_flex.py @@ -0,0 +1,303 @@ +from typing import Iterable, Union + +import pandas as pd +import pytest + +from portfolyo import testing, tools + +COMMON_END = "2022-02-02" + +TESTCASES = [ # startdates, freq, expected_startdate + # One starts at first day of year. + (("2020-01-01", "2020-01-20"), "15T", "2020-01-20"), + (("2020-01-01", "2020-01-20"), "15T", "2020-01-20"), + (("2020-01-01", "2020-01-20"), "H", "2020-01-20"), + (("2020-01-01", "2020-01-20"), "H", "2020-01-20"), + (("2020-01-01", "2020-01-20"), "D", "2020-01-20"), + (("2020-01-01", "2020-01-20"), "D", "2020-01-20"), + (("2020-01-01", "2020-03-01"), "MS", "2020-03-01"), + (("2020-01-01", "2020-03-01"), "MS", "2020-03-01"), + (("2020-01-01", "2020-04-01"), "QS", "2020-04-01"), + (("2020-01-01", "2020-04-01"), "QS", "2020-04-01"), + (("2020-01-01", "2021-01-01"), "AS", "2021-01-01"), + (("2020-01-01", "2021-01-01"), "AS", "2021-01-01"), + # Both start in middle of year. + (("2020-04-21", "2020-06-20"), "15T", "2020-06-20"), + (("2020-04-21", "2020-06-20"), "15T", "2020-06-20"), + (("2020-04-21", "2020-06-20"), "H", "2020-06-20"), + (("2020-04-21", "2020-06-20"), "H", "2020-06-20"), + (("2020-04-21", "2020-06-20"), "D", "2020-06-20"), + (("2020-04-21", "2020-06-20"), "D", "2020-06-20"), +] + +COMMON_END_2 = "2023-01-01" +TESTCASES_2 = [ # startdates, freq, expected_dates + # One starts at first day of year. + (("2020-01-01", "2020-01-20"), ("15T", "H"), "2020-01-20"), + (("2020-01-01", "2020-01-20"), ("15T", "D"), "2020-01-20"), + (("2022-04-01", "2021-02-01"), ("H", "MS"), "2022-04-01"), + (("2020-01-01", "2020-04-01"), ("H", "QS"), "2020-04-01"), + (("2020-01-01", "2021-01-01"), ("D", "AS"), "2021-01-01"), + # Both start in middle of year. + (("2020-04-21", "2020-06-20"), ("15T", "H"), "2020-06-20"), + (("2020-04-21", "2020-06-20"), ("15T", "D"), "2020-06-20"), + (("2020-04-21", "2020-07-01"), ("H", "MS"), "2020-07-01"), + (("2020-04-21", "2020-07-01"), ("H", "QS"), "2020-07-01"), + (("2020-04-21", "2021-01-01"), ("D", "AS"), "2021-01-01"), +] + + +def get_idx( + startdate: str, + starttime: str, + tz: str, + freq: str, + enddate: str, +) -> pd.DatetimeIndex: + # Empty index. + if startdate is None: + return pd.DatetimeIndex([], freq=freq, tz=tz) + # Normal index. + ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) + ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) + return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize(("startdates", "freq", "expected_startdate"), TESTCASES) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +# @pytest.mark.parametrize("indexorframe", ["idx", "fr"]) +def test_intersect_flex_ignore_start_of_day( + # indexorframe: str, + startdates: Iterable[str], + starttime: str, + tz: str, + freq: str, + expected_startdate: str, +): + otherstarttime = "00:00" if starttime == "06:00" else "06:00" + idxs = [ + get_idx( + startdates[0], + starttime, + tz, + freq, + COMMON_END, + ), + get_idx( + startdates[1], + otherstarttime, + tz, + freq, + COMMON_END, + ), + ] + do_test_intersect( + "idx", + idxs, + ValueError if freq == "15T" or freq == "H" else expected_startdate, + expected_tz=tz, + expected_freq=freq, + expected_starttime=starttime, + expected_otherstarttime=otherstarttime, + expected_othertz=tz, + expected_otherfreq=freq, + enddate=COMMON_END, + ignore_start_of_day=True, + ) + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +@pytest.mark.parametrize(("startdates", "freq", "expected_startdate"), TESTCASES) +# @pytest.mark.parametrize("indexorframe", ["idx", "fr"]) +def test_intersect_flex_ignore_tz( + # indexorframe: str, + startdates: Iterable[str], + starttime: str, + tz: str, + freq: str, + expected_startdate: str, +): + othertz = None if tz == "Europe/Berlin" else "Europe/Berlin" + idxs = [ + get_idx(startdates[0], starttime, tz, freq, COMMON_END), + get_idx(startdates[1], starttime, othertz, freq, COMMON_END), + ] + do_test_intersect( + "idx", + idxs, + expected_startdate, + expected_tz=tz, + expected_freq=freq, + expected_starttime=starttime, + expected_otherstarttime=starttime, + expected_othertz=othertz, + expected_otherfreq=freq, + enddate=COMMON_END, + ignore_tz=True, + ) + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize(("startdates", "freq", "expected_startdate"), TESTCASES_2) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +def test_intersect_flex_ignore_freq( + # indexorframe: str, + startdates: Iterable[str], + starttime: str, + tz: str, + freq: Iterable[str], + expected_startdate: str, +): + """Test if intersection of indices with distinct frequencies gives correct result.""" + + idxs = [ + get_idx(startdates[0], starttime, tz, freq[0], COMMON_END_2), + get_idx(startdates[1], starttime, tz, freq[1], COMMON_END_2), + ] + do_test_intersect( + "idx", + idxs, + expected_startdate, + expected_tz=tz, + expected_freq=freq[0], + expected_starttime=starttime, + expected_otherstarttime=starttime, + expected_othertz=tz, + expected_otherfreq=freq[1], + enddate=COMMON_END_2, + ignore_freq=True, + ) + + +@pytest.mark.parametrize("tz", [None, "Europe/Berlin", "Asia/Kolkata"]) +@pytest.mark.parametrize(("startdates", "freq", "expected_startdate"), TESTCASES_2) +@pytest.mark.parametrize("starttime", ["00:00", "06:00"]) +def test_ignore_all( # indexorframe: str, + startdates: Iterable[str], + starttime: str, + tz: str, + freq: Iterable[str], + expected_startdate: str, +): + otherstarttime = "00:00" if starttime == "06:00" else "06:00" + othertz = None if tz == "Europe/Berlin" else "Europe/Berlin" + idxs = [ + get_idx(startdates[0], starttime, tz, freq[0], COMMON_END_2), + get_idx(startdates[1], otherstarttime, othertz, freq[1], COMMON_END_2), + ] + do_test_intersect( + "idx", + idxs, + ( + ValueError + if freq[0] == "15T" or freq[0] == "H" or freq[1] == "15T" or freq[1] == "H" + else expected_startdate + ), + expected_tz=tz, + expected_freq=freq[0], + expected_starttime=starttime, + expected_otherstarttime=otherstarttime, + expected_othertz=othertz, + expected_otherfreq=freq[1], + enddate=COMMON_END_2, + ignore_freq=True, + ignore_start_of_day=True, + ignore_tz=True, + ) + + +def do_test_intersect( + indexorframe: str, + idxs: Iterable[pd.DatetimeIndex], + expected_startdate: Union[str, Exception], + expected_starttime: str = None, + expected_tz: str = None, + expected_freq: str = None, + expected_otherstarttime: str = None, + expected_othertz: str = None, + expected_otherfreq: str = None, + enddate: str = None, + ignore_start_of_day: bool = False, + ignore_tz: bool = False, + ignore_freq: bool = False, +): + if indexorframe == "idx": + do_test_intersect_index( + idxs, + expected_startdate, + expected_starttime, + expected_tz, + expected_freq, + expected_otherstarttime, + expected_othertz, + expected_otherfreq, + enddate, + ignore_start_of_day, + ignore_tz, + ignore_freq, + ) + + +def do_test_intersect_index( + idxs: Iterable[pd.DatetimeIndex], + expected_startdate: Union[str, Exception], + expected_starttime: str = None, + expected_tz: str = None, + expected_freq: str = None, + expected_otherstarttime: str = None, + expected_othertz: str = None, + expected_otherfreq: str = None, + enddate: str = None, + ignore_start_of_day: bool = False, + ignore_tz: bool = False, + ignore_freq: bool = False, +): + # Error case. + if isinstance(expected_startdate, type) and issubclass( + expected_startdate, Exception + ): + with pytest.raises(expected_startdate): + tools.intersect.indices_flex( + *idxs, + ignore_start_of_day=False, + ignore_tz=False, + ignore_freq=ignore_freq, + ) + return + # Normal case. + out_a, out_b = tools.intersect.indices_flex( + *idxs, + ignore_start_of_day=ignore_start_of_day, + ignore_tz=ignore_tz, + ignore_freq=ignore_freq, + ) + expected_a = get_idx( + expected_startdate, + expected_starttime, + expected_tz, + expected_freq, + enddate, + ) + expected_b = get_idx( + expected_startdate, + expected_otherstarttime, + expected_othertz, + expected_otherfreq, + enddate, + ) + testing.assert_index_equal(out_a, expected_a) + testing.assert_index_equal(out_b, expected_b) + + +def test_intersect_flex_dst(): + """Test if intersection keeps working if DST-boundary is right at end.""" + i1 = pd.date_range("2020", "2020-03-29", freq="D", tz="Europe/Berlin") + i2 = pd.date_range("2020", "2020-03-30", freq="D", tz="Europe/Berlin") + + expected = pd.date_range("2020", "2020-03-29", freq="D", tz="Europe/Berlin") + + result1, result2 = tools.intersect.indices_flex(i1, i2) + testing.assert_index_equal(result1, expected) + testing.assert_index_equal(result2, expected) diff --git a/tests/tools/test_intersect_flex_frame.py b/tests/tools/test_intersect_flex_frame.py new file mode 100644 index 0000000..5c42b05 --- /dev/null +++ b/tests/tools/test_intersect_flex_frame.py @@ -0,0 +1,175 @@ +import pandas as pd +import pytest + +from portfolyo import testing, tools + + +@pytest.mark.parametrize("types", ["series", "df"]) +@pytest.mark.parametrize("ignore_tz", [True, False]) +def test_frames_ignore_tz(types: str, ignore_tz: bool): + idx_a = pd.date_range( + "2020", "2022", freq="MS", inclusive="left", tz="Europe/Berlin" + ) + a = pd.Series(range(0, 24), idx_a) + + idx_b = pd.date_range("2020-02", "2021-09", freq="MS", inclusive="left") + b = pd.Series(range(0, 19), idx_b) + + exp_idx_a = pd.date_range( + "2020-02", "2021-09", freq="MS", inclusive="left", tz="Europe/Berlin" + ) + exp_idx_b = idx_b + exp_a = pd.Series(range(1, 20), exp_idx_a) + exp_b = pd.Series(range(0, 19), exp_idx_b) + + if types == "series": + if not ignore_tz: + with pytest.raises(ValueError): + _ = tools.intersect.frames(a, b, ignore_tz=ignore_tz) + return + result_a, result_b = tools.intersect.frames(a, b, ignore_tz=ignore_tz) + testing.assert_series_equal(result_a, exp_a) + testing.assert_series_equal(result_b, exp_b) + else: + a, b = pd.DataFrame({"col_a": a}), pd.DataFrame({"col_b": b}) + if not ignore_tz: + with pytest.raises(ValueError): + _ = tools.intersect.frames(a, b, ignore_tz=ignore_tz) + return + exp_a, exp_b = pd.DataFrame({"col_a": exp_a}), pd.DataFrame({"col_b": exp_b}) + result_a, result_b = tools.intersect.frames(a, b, ignore_tz=ignore_tz) + testing.assert_frame_equal(result_a, exp_a) + testing.assert_frame_equal(result_b, exp_b) + + +@pytest.mark.parametrize("types", ["series", "df"]) +@pytest.mark.parametrize("ignore_start_of_day", [True, False]) +def test_frames_ignore_start_of_day(types: str, ignore_start_of_day: bool): + idx_a = pd.date_range("2020 00:00", "2022 00:00", freq="MS", inclusive="left") + a = pd.Series(range(0, 24), idx_a) + + idx_b = pd.date_range("2020-02 06:00", "2021-09 06:00", freq="MS", inclusive="left") + b = pd.Series(range(0, 19), idx_b) + + exp_idx_a = pd.date_range( + "2020-02 00:00", "2021-09 00:00", freq="MS", inclusive="left" + ) + exp_idx_b = idx_b + exp_a = pd.Series(range(1, 20), exp_idx_a) + exp_b = pd.Series(range(0, 19), exp_idx_b) + if types == "series": + if not ignore_start_of_day: + with pytest.raises(ValueError): + _ = tools.intersect.frames( + a, b, ignore_start_of_day=ignore_start_of_day + ) + return + result_a, result_b = tools.intersect.frames( + a, b, ignore_start_of_day=ignore_start_of_day + ) + testing.assert_series_equal(result_a, exp_a) + testing.assert_series_equal(result_b, exp_b) + else: + a, b = pd.DataFrame({"col_a": a}), pd.DataFrame({"col_b": b}) + if not ignore_start_of_day: + with pytest.raises(ValueError): + _ = tools.intersect.frames( + a, b, ignore_start_of_day=ignore_start_of_day + ) + return + exp_a, exp_b = pd.DataFrame({"col_a": exp_a}), pd.DataFrame({"col_b": exp_b}) + result_a, result_b = tools.intersect.frames( + a, b, ignore_start_of_day=ignore_start_of_day + ) + testing.assert_frame_equal(result_a, exp_a) + testing.assert_frame_equal(result_b, exp_b) + + +@pytest.mark.parametrize("types", ["series", "df"]) +@pytest.mark.parametrize("ignore_freq", [True, False]) +def test_frames_ignore_freq(types: str, ignore_freq: bool): + idx_a = pd.date_range("2022-04-01", "2024-07-01", freq="QS", inclusive="left") + a = pd.Series(range(0, 9), idx_a) + + idx_b = pd.date_range("2021-01-01", "2024-01-01", freq="AS", inclusive="left") + b = pd.Series(range(0, 3), idx_b) + + exp_idx_a = pd.date_range("2023-01-01", "2024-01-01", freq="QS", inclusive="left") + exp_idx_b = pd.date_range("2023-01-01", "2024-01-01", freq="AS", inclusive="left") + exp_a = pd.Series(range(3, 7), exp_idx_a) + exp_b = pd.Series(range(2, 3), exp_idx_b) + if types == "series": + if not ignore_freq: + with pytest.raises(ValueError): + _ = tools.intersect.frames(a, b, ignore_freq=ignore_freq) + return + result_a, result_b = tools.intersect.frames(a, b, ignore_freq=ignore_freq) + testing.assert_series_equal(result_a, exp_a) + testing.assert_series_equal(result_b, exp_b) + else: + a, b = pd.DataFrame({"col_a": a}), pd.DataFrame({"col_b": b}) + if not ignore_freq: + with pytest.raises(ValueError): + _ = tools.intersect.frames(a, b, ignore_freq=ignore_freq) + return + exp_a, exp_b = pd.DataFrame({"col_a": exp_a}), pd.DataFrame({"col_b": exp_b}) + result_a, result_b = tools.intersect.frames(a, b, ignore_freq=ignore_freq) + testing.assert_frame_equal(result_a, exp_a) + testing.assert_frame_equal(result_b, exp_b) + + +@pytest.mark.parametrize("types", ["series", "df"]) +@pytest.mark.parametrize("ignore_all", [True, False]) +def test_frames_ignore_all(types: str, ignore_all: bool): + idx_a = pd.date_range( + "2022-04-01 00:00", + "2024-07-01 00:00", + freq="QS", + tz="Europe/Berlin", + inclusive="left", + ) + a = pd.Series(range(0, 9), idx_a) + + idx_b = pd.date_range( + "2021-01-01 06:00", "2024-01-01 06:00", freq="AS", inclusive="left" + ) + b = pd.Series(range(0, 3), idx_b) + + exp_idx_a = pd.date_range( + "2023-01-01 00:00", + "2024-01-01 00:00", + freq="QS", + tz="Europe/Berlin", + inclusive="left", + ) + exp_idx_b = pd.date_range( + "2023-01-01 06:00", "2024-01-01 06:00", freq="AS", inclusive="left" + ) + exp_a = pd.Series(range(3, 7), exp_idx_a) + exp_b = pd.Series(range(2, 3), exp_idx_b) + if types == "series": + if not ignore_all: + with pytest.raises(ValueError): + _ = tools.intersect.frames( + a, b, ignore_freq=False, ignore_start_of_day=False, ignore_tz=False + ) + return + result_a, result_b = tools.intersect.frames( + a, b, ignore_freq=True, ignore_start_of_day=True, ignore_tz=True + ) + testing.assert_series_equal(result_a, exp_a) + testing.assert_series_equal(result_b, exp_b) + else: + a, b = pd.DataFrame({"col_a": a}), pd.DataFrame({"col_b": b}) + if not ignore_all: + with pytest.raises(ValueError): + _ = tools.intersect.frames( + a, b, ignore_freq=False, ignore_start_of_day=False, ignore_tz=False + ) + return + exp_a, exp_b = pd.DataFrame({"col_a": exp_a}), pd.DataFrame({"col_b": exp_b}) + result_a, result_b = tools.intersect.frames( + a, b, ignore_freq=True, ignore_start_of_day=True, ignore_tz=True + ) + testing.assert_frame_equal(result_a, exp_a) + testing.assert_frame_equal(result_b, exp_b) diff --git a/tests/tools2/test_indexable.py b/tests/tools2/test_indexable.py new file mode 100644 index 0000000..6ca46ab --- /dev/null +++ b/tests/tools2/test_indexable.py @@ -0,0 +1,223 @@ +from typing import Union +import pandas as pd +import pytest + +import portfolyo as pf + + +def get_idx( + startdate: str, + starttime: str, + tz: str, + freq: str, + enddate: str, +) -> pd.DatetimeIndex: + # Empty index. + if startdate is None: + return pd.DatetimeIndex([], freq=freq, tz=tz) + # Normal index. + ts_start = pd.Timestamp(f"{startdate} {starttime}", tz=tz) + ts_end = pd.Timestamp(f"{enddate} {starttime}", tz=tz) + return pd.date_range(ts_start, ts_end, freq=freq, inclusive="left") + + +def create_obj( + series: pd.Series, name_obj: str +) -> Union[pd.DataFrame, pf.PfLine, pf.PfState]: + if name_obj == "pfline": + return pf.PfLine({"w": series}) + elif name_obj == "pfstate": + volume = pf.PfLine({"w": series}) + prices = pf.PfLine({"p": series}) + return pf.PfState(volume, prices) + else: + return pd.DataFrame({"col": series}) + + +def t_function(objtype: str): + if objtype == "series": + return pd.testing.assert_series_equal + elif objtype == "dataframe": + return pd.testing.assert_frame_equal + elif objtype == "pfline": + return pf.PfLine.__eq__ + else: + return pf.PfState.__eq__ + + +@pytest.mark.parametrize("first_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("second_obj", ["pfstate", "pfline", "series", "dataframe"]) +def test_intersect_freq_ignore( + first_obj: str, + second_obj: str, +): + """Test that intersection works properly on PfLines and/or PfStates with ignore_freq.""" + idx1 = get_idx("2022-04-01", "00:00", "Europe/Berlin", "QS", "2024-07-01") + s1 = pd.Series(range(len(idx1)), idx1) + + idx2 = get_idx("2021-01-01", "00:00", "Europe/Berlin", "MS", "2024-01-01") + s2 = pd.Series(range(len(idx2)), idx2) + + first = create_obj(s1, first_obj) if first_obj != "series" else s1 + second = create_obj(s2, second_obj) if second_obj != "series" else s2 + # Do intersection + intersect = pf.intersection(first, second, ignore_freq=True) + + # Expected results + expected_s1 = s1.iloc[:7] + expected_s2 = s2.iloc[15:48] + output_1 = ( + create_obj(expected_s1, first_obj) if first_obj != "series" else expected_s1 + ) + output_2 = ( + create_obj(expected_s2, second_obj) if second_obj != "series" else expected_s2 + ) + for a, b, objtype in zip([output_1, output_2], intersect, [first_obj, second_obj]): + fn = t_function(objtype) + fn(a, b) + + +@pytest.mark.parametrize("first_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("second_obj", ["pfstate", "pfline", "series", "dataframe"]) +def test_intersect_sod( + first_obj: str, + second_obj: str, +): + """Test that intersection works properly on PfLines and/or PfStates with ignore_sod.""" + idx1 = get_idx("2022-04-01", "00:00", "Europe/Berlin", "QS", "2024-07-01") + s1 = pd.Series(range(len(idx1)), idx1) + + idx2 = get_idx("2021-01-01", "06:00", "Europe/Berlin", "QS", "2024-01-01") + s2 = pd.Series(range(len(idx2)), idx2) + + first = create_obj(s1, first_obj) if first_obj != "series" else s1 + second = create_obj(s2, second_obj) if second_obj != "series" else s2 + # Do intersection + intersect = pf.intersection(first, second, ignore_start_of_day=True) + + # Expected results + expected_s1 = s1.iloc[:7] + expected_s2 = s2.iloc[5:12] + output_1 = ( + create_obj(expected_s1, first_obj) if first_obj != "series" else expected_s1 + ) + output_2 = ( + create_obj(expected_s2, second_obj) if second_obj != "series" else expected_s2 + ) + for a, b, objtype in zip([output_1, output_2], intersect, [first_obj, second_obj]): + fn = t_function(objtype) + fn(a, b) + + +@pytest.mark.parametrize("first_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("second_obj", ["pfstate", "pfline", "series", "dataframe"]) +def test_intersect_tz( + first_obj: str, + second_obj: str, +): + """Test that intersection works properly on PfLines and/or PfStates with ignore_tz.""" + idx1 = get_idx("2022-04-01", "00:00", "Europe/Berlin", "QS", "2024-07-01") + s1 = pd.Series(range(len(idx1)), idx1) + + idx2 = get_idx("2021-01-01", "00:00", None, "QS", "2024-01-01") + s2 = pd.Series(range(len(idx2)), idx2) + + first = create_obj(s1, first_obj) if first_obj != "series" else s1 + second = create_obj(s2, second_obj) if second_obj != "series" else s2 + # Do intersection + intersect = pf.intersection(first, second, ignore_tz=True) + + # Expected results + expected_s1 = s1.iloc[:7] + expected_s2 = s2.iloc[5:12] + output_1 = ( + create_obj(expected_s1, first_obj) if first_obj != "series" else expected_s1 + ) + output_2 = ( + create_obj(expected_s2, second_obj) if second_obj != "series" else expected_s2 + ) + for a, b, objtype in zip([output_1, output_2], intersect, [first_obj, second_obj]): + fn = t_function(objtype) + fn(a, b) + + +@pytest.mark.parametrize("first_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("second_obj", ["pfstate", "pfline", "series", "dataframe"]) +def test_intersect_ignore_all( + first_obj: str, + second_obj: str, +): + """Test that intersection works properly on PfLines and/or PfStates with ignore_all.""" + idx1 = get_idx("2022-04-01", "00:00", "Europe/Berlin", "QS", "2024-07-01") + s1 = pd.Series(range(len(idx1)), idx1) + + idx2 = get_idx("2021-01-01", "06:00", None, "MS", "2024-01-01") + s2 = pd.Series(range(len(idx2)), idx2) + + first = create_obj(s1, first_obj) if first_obj != "series" else s1 + second = create_obj(s2, second_obj) if second_obj != "series" else s2 + # Do intersection + intersect = pf.intersection( + first, second, ignore_freq=True, ignore_tz=True, ignore_start_of_day=True + ) + + # Expected results + expected_s1 = s1.iloc[:7] + expected_s2 = s2.iloc[15:48] + output_1 = ( + create_obj(expected_s1, first_obj) if first_obj != "series" else expected_s1 + ) + output_2 = ( + create_obj(expected_s2, second_obj) if second_obj != "series" else expected_s2 + ) + for a, b, objtype in zip([output_1, output_2], intersect, [first_obj, second_obj]): + fn = t_function(objtype) + fn(a, b) + + +@pytest.mark.parametrize("first_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("second_obj", ["pfstate", "pfline", "series", "dataframe"]) +@pytest.mark.parametrize("third_obj", ["pfstate", "pfline", "series", "dataframe"]) +def test_intersect_ignore_all_3obj( + first_obj: str, + second_obj: str, + third_obj: str, +): + """Test that intersection works properly on PfLines and/or PfStates with ignore_all.""" + idx1 = get_idx("2022-04-01", "00:00", "Europe/Berlin", "QS", "2024-07-01") + s1 = pd.Series(range(len(idx1)), idx1) + + idx2 = get_idx("2021-01-01", "06:00", None, "MS", "2024-01-01") + s2 = pd.Series(range(len(idx2)), idx2) + + idx3 = get_idx("2023-01-01", "00:00", "Asia/Kolkata", "AS", "2025-01-01") + s3 = pd.Series(range(len(idx3)), idx3) + + first = create_obj(s1, first_obj) if first_obj != "series" else s1 + second = create_obj(s2, second_obj) if second_obj != "series" else s2 + third = create_obj(s3, third_obj) if third_obj != "series" else s3 + + # Do intersection + intersect = pf.intersection( + first, second, third, ignore_freq=True, ignore_tz=True, ignore_start_of_day=True + ) + + # Expected results + expected_s1 = s1.iloc[3:7] + expected_s2 = s2.iloc[24:36] + expected_s3 = s3.iloc[:1] + output_1 = ( + create_obj(expected_s1, first_obj) if first_obj != "series" else expected_s1 + ) + output_2 = ( + create_obj(expected_s2, second_obj) if second_obj != "series" else expected_s2 + ) + output_3 = ( + create_obj(expected_s3, third_obj) if third_obj != "series" else expected_s3 + ) + + for a, b, objtype in zip( + [output_1, output_2, output_3], intersect, [first_obj, second_obj, third_obj] + ): + fn = t_function(objtype) + fn(a, b)