Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add numba engine for rolling apply #30151

Merged
merged 56 commits into from
Dec 27, 2019
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
3b9bff8
Add numba to import_optional_dependencies
Dec 8, 2019
9a302bf
Start adding keywords
Dec 8, 2019
0e9a600
Modify apply for numba and cython
Dec 9, 2019
36a77ed
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 9, 2019
dbb2a9b
Add numba as optional dependency
Dec 9, 2019
f0e9a4d
Add premil tests
Dec 9, 2019
1250aee
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 10, 2019
4e7fd1a
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 11, 2019
cb976cf
Add numba to requirements-dev, type and reorder signature in apply
Dec 11, 2019
45420bb
Move numba routines to its own file
Dec 11, 2019
17851cf
Adjust signature in top level function as well
Dec 11, 2019
20767ca
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 11, 2019
9619f8d
Generate requirements-dev.txt using script
Dec 11, 2019
66fa69c
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 13, 2019
b8908ea
Add skip test decorator, add numba to a few builds
Dec 13, 2019
135f2ad
black
Dec 13, 2019
34a5687
don't rejit a user's jitted function
Dec 13, 2019
6da8199
Add numba/cython comparison test
Dec 13, 2019
123f77e
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 17, 2019
54e74d1
Remove typing for now
Dec 17, 2019
04d3530
Remove sub description for doc failures?
Dec 17, 2019
4bbf587
Fix test function
Dec 17, 2019
f849bc7
test user predefined jit function, clarify docstring
Dec 17, 2019
0c30e48
Apply engine kwargs to function as well
Dec 17, 2019
c4c952e
Clairfy documentation
Dec 17, 2019
8645976
Clarify what engine_kwargs applies to
Dec 17, 2019
987c916
Start section for numba rolling apply
Dec 17, 2019
b775684
Lint
Dec 17, 2019
2e04e60
clarify note
Dec 17, 2019
9b20ff5
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 18, 2019
0c14033
Add apply function cache to save compiled numba functions
Dec 19, 2019
c7106dc
Add performance example
Dec 19, 2019
1640085
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 21, 2019
2846faf
Remove whitespace
Dec 21, 2019
5a645c0
Address lint errors and separate apply tests
Dec 22, 2019
6bac000
Add whatsnew note
Dec 22, 2019
6f1c73f
Skip apply tests for numba not installed, lint
Dec 22, 2019
a890337
Add typing
Dec 22, 2019
0a9071c
Add more typing
Dec 22, 2019
9d8d40b
Formatting cleanups
Dec 23, 2019
84c3491
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 24, 2019
a429206
Address Jeff's comments
Dec 24, 2019
5826ad9
Black
Dec 24, 2019
cf7571b
Add clarification
Dec 24, 2019
4bc9787
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 24, 2019
18eed60
Move function to module level
Dec 24, 2019
f715b55
move cache check higher up
Dec 24, 2019
6a765bf
Address Will's comments
Dec 24, 2019
af3fe50
Type Callable in generate_numba_apply_func
Dec 24, 2019
eb7b5e1
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 24, 2019
f7dfcf4
use ellipsis, cannot specify np.ndarray as well
Dec 24, 2019
a42a960
Remove trailing whitespace in apply docstring
Dec 24, 2019
d019830
Address Will's and Brock's comments
Dec 25, 2019
29d145f
Fix typing
Dec 25, 2019
248149c
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
Dec 26, 2019
a3da51e
Address followup comments
Dec 26, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/deps/azure-36-minimum_versions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- beautifulsoup4=4.6.0
- bottleneck=1.2.1
- jinja2=2.8
- numba=0.46.0
- numexpr=2.6.2
- numpy=1.13.3
- openpyxl=2.5.7
Expand Down
1 change: 1 addition & 0 deletions ci/deps/azure-windows-36.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- bottleneck
- fastparquet>=0.3.2
- matplotlib=3.0.2
- numba
- numexpr
- numpy=1.15.*
- openpyxl
Expand Down
1 change: 1 addition & 0 deletions doc/source/getting_started/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ gcsfs 0.2.2 Google Cloud Storage access
html5lib HTML parser for read_html (see :ref:`note <optional_html>`)
lxml 3.8.0 HTML parser for read_html (see :ref:`note <optional_html>`)
matplotlib 2.2.2 Visualization
numba 0.46.0 Alternative execution engine for rolling operations
openpyxl 2.5.7 Reading / writing for xlsx files
pandas-gbq 0.8.0 Google Big Query access
psycopg2 PostgreSQL engine for sqlalchemy
Expand Down
47 changes: 47 additions & 0 deletions doc/source/user_guide/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ We provide a number of common statistical functions:
:meth:`~Rolling.cov`, Unbiased covariance (binary)
:meth:`~Rolling.corr`, Correlation (binary)

.. _stats.rolling_apply:

Rolling Apply
~~~~~~~~~~~~~

The :meth:`~Rolling.apply` function takes an extra ``func`` argument and performs
generic rolling computations. The ``func`` argument should be a single function
that produces a single value from an ndarray input. Suppose we wanted to
Expand All @@ -334,6 +339,48 @@ compute the mean absolute deviation on a rolling basis:
@savefig rolling_apply_ex.png
s.rolling(window=60).apply(mad, raw=True).plot(style='k')

.. versionadded:: 1.0

Additionally, :meth:`~Rolling.apply` can leverage `Numba <https://numba.pydata.org/>`__
jreback marked this conversation as resolved.
Show resolved Hide resolved
if installed as an optional dependency. The apply aggregation can be executed using Numba by specifying
``engine='numba'`` and ``engine_kwargs`` arguments (``raw`` must also be set to ``True``).
Numba will be applied in potentially two routines:

1. If ``func`` is a standard Python function, the engine will `JIT <http://numba.pydata.org/numba-doc/latest/user/overview.html>`__
the passed function. ``func`` can also be a JITed function in which case the engine will not JIT the function again.
2. The engine will JIT the for loop where the apply function is applied to each window.

The ``engine_kwargs`` argument is a dictionary of keyword arguments that will be passed into the
`numba.jit decorator <https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html#numba.jit>`__.
These keyword arguments will be applied to *both* the passed function (if a standard Python function)
and the apply for loop over each window. Currently only ``nogil``, ``nopython``, and ``parallel`` are supported,
and their default values are set to ``False``, ``True`` and ``False`` respectively.

.. note::

In terms of performance, **the first time a function is run using the Numba engine will be slow**
as Numba will have some function compilation overhead. However, ``rolling`` objects will cache
the function and subsequent calls will be fast. In general, the Numba engine is performant with
a larger amount of data points (e.g. 1+ million).

.. code-block:: ipython

In [1]: data = pd.Series(range(1_000_000))

In [2]: roll = data.rolling(10)

In [3]: def f(x):
...: return np.sum(x) + 5
# Run the first time, compilation time will affect performance
In [4]: %timeit -r 1 -n 1 roll.apply(f, engine='numba', raw=True) # noqa: E225
jreback marked this conversation as resolved.
Show resolved Hide resolved
jreback marked this conversation as resolved.
Show resolved Hide resolved
1.23 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
# Function is cached and performance will improve
In [5]: %timeit roll.apply(f, engine='numba', raw=True)
188 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [6]: %timeit roll.apply(f, engine='cython', raw=True)
3.92 s ± 59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

.. _stats.rolling_window:

Rolling windows
Expand Down
13 changes: 13 additions & 0 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ You can use the alias ``"boolean"`` as well.
s = pd.Series([True, False, None], dtype="boolean")
s

.. _whatsnew_1000.numba_rolling_apply:

Using Numba in ``rolling.apply``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We've added an ``engine`` keyword to :meth:`~Rolling.apply` that allows the user to execute the
routine using `Numba <https://numba.pydata.org/>`__ instead of Cython. Using the Numba engine
can yield significant performance gains if the apply function can operate on numpy arrays and
jreback marked this conversation as resolved.
Show resolved Hide resolved
the data set is larger. For more details, see :ref:`rolling apply documentation <stats.rolling_apply>`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"the data set is larger" here is pretty vague. is the perf gain a function of the array size or more about the user-defined function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhat both, but more obvious with the data size. I can make this more specific.

(:issue:`28987`)

.. _whatsnew_1000.custom_window:

Defining custom windows for rolling operations
Expand Down Expand Up @@ -428,6 +439,8 @@ Optional libraries below the lowest tested version may still work, but are not c
+-----------------+-----------------+---------+
| matplotlib | 2.2.2 | |
+-----------------+-----------------+---------+
| numba | 0.46.0 | |
jreback marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an X for this

+-----------------+-----------------+---------+
| openpyxl | 2.5.7 | X |
+-----------------+-----------------+---------+
| pyarrow | 0.12.0 | X |
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dependencies:
- matplotlib>=2.2.2 # pandas.plotting, Series.plot, DataFrame.plot
- numexpr>=2.6.8
- scipy>=1.1
- numba>=0.46.0

# optional for io
- beautifulsoup4>=4.6.0 # pandas.read_html
Expand Down
1 change: 1 addition & 0 deletions pandas/compat/_optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"xlrd": "1.1.0",
"xlwt": "1.2.0",
"xlsxwriter": "0.9.8",
"numba": "0.46.0",
}


Expand Down
1 change: 1 addition & 0 deletions pandas/core/window/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _apply(
floor: int = 1,
is_weighted: bool = False,
name: Optional[str] = None,
use_numba_cache: Optional[bool] = False,
jreback marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
"""
Expand Down
100 changes: 100 additions & 0 deletions pandas/core/window/numba_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import types
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency


def make_rolling_apply(func, args, nogil, parallel, nopython):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker here since this is large enough, but would be nice to annotate this in a follow up

numba = import_optional_dependency("numba")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a doc-string that says what this function does (the parameters are already documented elsewhere, maybe just mention that)


if parallel:
loop_range = numba.prange
else:
loop_range = range

if isinstance(func, numba.targets.registry.CPUDispatcher):
# Don't jit a user passed jitted function
numba_func = func
else:

@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stuartarchibald sorry for the ping, but I see that generated_jit has been deprecated in numba 0.57. IIRC you helped me add this a while back and am lost on how to write this in terms of overload

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mroeschke no problem, I can try and help with this. I think it needs to look a bit like this (for reference, this is untested, I am just guessing from the context! Also, the pandas variant is obviously wrapped to close over some configuration which I've omitted, so consider this as the function body of make_rolling_apply. I've left comments inline to try and explain what's going on):

import types
import numpy as np
from numba.extending import overload, is_jitted
from numba import njit
import numba


# this provides a local definition to overload
def overload_target(window, *_args):
    # If JIT is disabled, this function will run, so write the implementation here!
    pass


nopython = True
nogil = True
parallel = False

# pretend this is an arg to `make_rolling_apply`
def func(window, *args):
    return window * 2 + args[0]


@overload(overload_target, jit_options={'nopython':nopython, 'nogil':nogil,
                                        'parallel':parallel})
def ol_overload_target(window, *_args):
    # This function "overloads" `overload_target`, whenever the Numba compiler
    # "sees" `overload_target` it will use this function.

    # Using `is_jitted` to avoid `isinstance` on
    # `numba.targets.registry.CPUDispatcher` as that may be considered an
    # internal Numba detail.
    if is_jitted(func):
        # it's already JIT compiled so just reference it
        overload_target_impl = func
    elif getattr(np, func.__name__, False) is func or isinstance(
        func, types.BuiltinFunctionType
    ):
        # it's a NumPy function or builtin so just reference it
        overload_target_impl = func
    else:
        # it's a Python function, so register it as JIT compilable and reference
        # that
        overload_target_impl = numba.jit(func, nopython=nopython, nogil=nogil)

    # This is the Numba implementation of the overload, it will just be JIT
    # compiled whenever the compiler "sees" a reference to "overload_target" in
    # code it is compiling.
    def impl(window, *_args):
        return overload_target_impl(window, *_args)

    return impl


# demo

@njit
def roll_apply(window, *_args):
    return overload_target(window, *_args)


print(roll_apply(np.arange(10.), 1.23))

@overload is basically saying to Numba "when you see this specific python function (the one in the first argument in the @overload decorator) use this implementation". The concept about there being a "typing" part that can be used to dispatch different variants based on type is exactly the same as in @generated_jit. The largest difference is what happens if the JIT compiler is turned off. In the case of @overload the python function being overloaded will run, i.e., the code just executes as would be expected in the interpreter. Whereas in the case of @generated_jit, because the pure python implementation and the Numba implementation are the same function, if you turn the JIT compiler off it will just break (the value returned when calling a @generated_jit function is a function implementing the Numba specialisation). Essentially, @generated_jit is like doing @overload but the function being decorated is also the function being overloaded.

Hope this helps?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reply! We had a PR recently that refactored this to use extending.register_jittable. Would that be a sufficient alternative? https://github.com/pandas-dev/pandas/pull/53455/files

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem! I just took a look at the patch above, I think it'd work but think it might lose some of the dispatch ability offered by generated_jit/overload. As I understand it, the original code would have let a NumPy function or a built-in be passed in as the "user function", whereas I think the register_jittable version requires a user defined Python function. It may be that the register_jittable version is a sufficient alternative for the need/use cases in practice, in which case, it seems appropriate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great thanks for the context! Yeah this function should expect a custom UDF so thanks for the confirmation

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad to get this resolved, thanks for confirming too! It sounds like the replacement above is appropriate. If there are any more issues/queries feel free to open issues on the Numba issue tracker (or ping here!).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stuartarchibald I'm running into a rolling apply issue with pandas 2.1.1 and numba 0.58 that might be related. Discussion is here:
https://numba.discourse.group/t/pandas-source-of-old-style-error-capturing-warning/2169/8

def numba_func(window, *_args):
if getattr(np, func.__name__, False) is func or isinstance(
func, types.BuiltinFunctionType
):
jreback marked this conversation as resolved.
Show resolved Hide resolved
jf = func
else:
jf = numba.jit(func, nopython=nopython)
jreback marked this conversation as resolved.
Show resolved Hide resolved

def impl(window, *_args):
return jf(window, *_args)

return impl

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def roll_apply(
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
) -> np.ndarray:
result = np.empty(len(begin))
for i in loop_range(len(result)):
start = begin[i]
stop = end[i]
window = values[start:stop]
count_nan = np.sum(np.isnan(window))
if len(window) - count_nan >= minimum_periods:
result[i] = numba_func(window, *args)
else:
result[i] = np.nan
return result

return roll_apply


def generate_numba_apply_func(
args: Tuple,
kwargs: Dict[str, Any],
func: Callable[..., Scalar],
engine_kwargs: Optional[Dict[str, bool]],
):
"""
Generate a numba jitted apply function specified by values from engine_kwargs.

1. jit the user's function
2. Return a rolling apply function with the jitted function inline

Configurations specified in engine_kwargs apply to both the user's
function _AND_ the rolling apply function.
jreback marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
args : tuple
*args to be passed into the function
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each window and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit

Returns
-------
Numba function
"""

if engine_kwargs is None:
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)

if kwargs and nopython:
raise ValueError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)

return make_rolling_apply(func, args, nogil, parallel, nopython)
Loading