Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental Feature - Accelerated NumPy #672

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dist/
doc/build/
doc/source/**/generated/
arch/univariate/recursions.c
**/.DS_Store
3 changes: 2 additions & 1 deletion arch/covariance/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from functools import cached_property
from typing import SupportsInt, cast

import numpy as np
from pandas import DataFrame, Index
from pandas.util._decorators import Substitution

from arch.experimental import numpy as np
from arch.typing import ArrayLike, Float64Array
from arch.utility.array import AbstractDocStringInheritor, ensure1d, ensure2d

Expand Down Expand Up @@ -398,6 +398,7 @@ def cov(self) -> CovarianceEstimate:
sr = x.T @ x / df
w = self.kernel_weights
num_weights = w.shape[0]
x = np.asarray(self._x)
oss = np.zeros((k, k))
for i in range(1, num_weights):
oss += w[i] * (x[i:].T @ x[:-i]) / df
Expand Down
19 changes: 19 additions & 0 deletions arch/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .engine import (
LinAlgEngine,
NumpyEngine,
backend,
linalg,
numpy,
set_backend,
use_backend,
)

__all__ = [
"LinAlgEngine",
"NumpyEngine",
"backend",
"linalg",
"numpy",
"set_backend",
"use_backend",
]
139 changes: 139 additions & 0 deletions arch/experimental/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from contextlib import contextmanager
from typing import Any

_BACKEND_ENGINE = "numpy"
_SUPPORTED_ENGINES = ["numpy", "tensorflow", "cupy", "jax"]


def backend():
return _BACKEND_ENGINE


def set_backend(library_name):
"""
Set backend engine.

The function sets the backend engine in global level.

Parameters
----------
library_name : str
Library name. Default is `numpy`. Options are `numpy`, `tensorflow`,
`cupy` and `jax`.
"""
library_name = library_name.lower()
assert library_name in _SUPPORTED_ENGINES, (
"Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not "
f"{library_name}"
)
global _BACKEND_ENGINE
_BACKEND_ENGINE = library_name
return _BACKEND_ENGINE


@contextmanager
def use_backend(library_name="numpy"):
"""
NumPy engine selection.

The function is a context manager to enable users to switch to a
specific library as a replacement of NumPy in CPU.

Parameters
----------
library_name : str
Library name. Default is `numpy`. Options are `numpy`, `tensorflow`,
`cupy` and `jax`.
"""
library_name = library_name.lower()
assert library_name.lower() in _SUPPORTED_ENGINES, (
"Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not "
f"{library_name}"
)
global _BACKEND_ENGINE
_original = _BACKEND_ENGINE
try:
_BACKEND_ENGINE = library_name
if _BACKEND_ENGINE == "tensorflow":
import tensorflow.experimental.numpy as np

Check warning on line 58 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L58

Added line #L58 was not covered by tests

np.experimental_enable_numpy_behavior()

Check warning on line 60 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L60

Added line #L60 was not covered by tests
yield
finally:
_BACKEND_ENGINE = _original


class NumpyEngine:
"""
NumPy engine.
"""

def __getattribute__(self, __name: str) -> Any:
if __name == "name":
return _BACKEND_ENGINE

try:
if _BACKEND_ENGINE == "numpy":
import numpy as anp
elif _BACKEND_ENGINE == "tensorflow":
import tensorflow.experimental.numpy as anp

Check warning on line 79 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L79

Added line #L79 was not covered by tests
elif _BACKEND_ENGINE == "cupy":
import cupy as anp

Check warning on line 81 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L81

Added line #L81 was not covered by tests
elif _BACKEND_ENGINE == "jax":
import jax.numpy as anp
else:
raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}")

Check warning on line 85 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L85

Added line #L85 was not covered by tests
except ImportError:
raise ImportError(
"Library `numpy` cannot be imported from backend engine "
f"{_BACKEND_ENGINE}. Please make sure to install the library "
f"via `pip install {_BACKEND_ENGINE}`."
)

try:
return getattr(anp, __name)
except AttributeError:
raise AttributeError(
f"Cannot get attribute / function ({__name}) from numpy library in "
f"backend engine {_BACKEND_ENGINE}"
)


class LinAlgEngine:
"""
Linear algebra engine.
"""

def __getattribute__(self, __name: str) -> Any:
if __name == "name":
return _BACKEND_ENGINE

try:
if _BACKEND_ENGINE == "numpy":
import numpy.linalg as alinalg
elif _BACKEND_ENGINE == "tensorflow":
import tensorflow.linalg as alinalg

Check warning on line 115 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L115

Added line #L115 was not covered by tests
elif _BACKEND_ENGINE == "cupy":
import cupy.linalg as alinalg

Check warning on line 117 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L117

Added line #L117 was not covered by tests
elif _BACKEND_ENGINE == "jax":
import jax.numpy.linalg as alinalg
else:
raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}")

Check warning on line 121 in arch/experimental/engine.py

View check run for this annotation

Codecov / codecov/patch

arch/experimental/engine.py#L121

Added line #L121 was not covered by tests
except ImportError:
raise ImportError(
"Library `linalg` cannot be imported from backend engine "
f"{_BACKEND_ENGINE}. Please make sure to install the library "
f"via `pip install {_BACKEND_ENGINE}`."
)

try:
return getattr(alinalg, __name)
except AttributeError:
raise AttributeError(
f"Cannot get attribute / function ({__name}) from linalg library in "
f"backend engine {_BACKEND_ENGINE}"
)


numpy = NumpyEngine()
linalg = LinAlgEngine()
94 changes: 94 additions & 0 deletions arch/tests/experimental/test_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import pytest

from arch.experimental.engine import (
_SUPPORTED_ENGINES,
backend,
linalg,
numpy,
set_backend,
use_backend,
)


def test_numpy_name():
for engine_name in _SUPPORTED_ENGINES:
if engine_name == "tensorflow":
continue

with use_backend(engine_name):
assert engine_name == numpy.name


def test_linalg_name():
for engine_name in _SUPPORTED_ENGINES:
if engine_name == "tensorflow":
continue

with use_backend(engine_name):
assert engine_name == linalg.name


def test_set_backend_eq():
assert set_backend(backend()) == backend()


def test_numpy_getattribute():
import numpy as np

with use_backend("numpy"):
array = numpy.array
assert array == np.array


def test_linalg_getattribute():
import numpy.linalg

with use_backend("numpy"):
inv = linalg.inv
assert inv == numpy.linalg.inv


def test_numpy_getattribute_failed_attr():
with use_backend("numpy"):
with pytest.raises(AttributeError) as exc:
numpy.xyz # noqa

assert str(exc.value) == (
"Cannot get attribute / function (xyz) from numpy library in "
"backend engine numpy"
)


def test_linalg_getattribute_failed_attr():
with use_backend("numpy"):
with pytest.raises(AttributeError) as exc:
linalg.xyz # noqa

assert str(exc.value) == (
"Cannot get attribute / function (xyz) from linalg library in "
"backend engine numpy"
)


def test_numpy_getattribute_failed_import():
with use_backend("jax"):
with pytest.raises(ImportError) as exc:
numpy.array # noqa

assert str(exc.value) == (
"Library `numpy` cannot be imported from backend engine "
"jax. Please make sure to install the library "
"via `pip install jax`."
)


def test_linalg_getattribute_failed_import():
with use_backend("jax"):
with pytest.raises(ImportError) as exc:
linalg.inv # noqa

assert str(exc.value) == (
"Library `linalg` cannot be imported from backend engine "
"jax. Please make sure to install the library "
"via `pip install jax`."
)
5 changes: 4 additions & 1 deletion arch/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
except ImportError: # pragma: no cover
pytestmark = pytest.mark.skip(reason=REASON)

SLOW_NOTEBOOKS = ["multiple-comparison_examples.ipynb"]
SLOW_NOTEBOOKS = [
"multiple-comparison_examples.ipynb",
"experimental_accelerated_numpy.ipynb",
]
if bool(os.environ.get("ARCH_TEST_SLOW_NOTEBOOKS", False)): # pragma: no cover
SLOW_NOTEBOOKS = []
kernel_name = "python%s" % sys.version_info.major
Expand Down
20 changes: 13 additions & 7 deletions arch/unitroot/unitroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@
squeeze,
sum as npsum,
)
from numpy.linalg import LinAlgError, inv, lstsq, matrix_rank, pinv, qr, solve
from numpy.linalg import LinAlgError, inv, lstsq, matrix_rank, pinv, solve
from pandas import DataFrame
from scipy.stats import norm
from statsmodels.iolib.summary import Summary
from statsmodels.iolib.table import SimpleTable
from statsmodels.regression.linear_model import OLS, RegressionResults
from statsmodels.tsa.tsatools import lagmat

from arch.experimental import linalg as alinalg, numpy as anp
from arch.typing import (
ArrayLike,
ArrayLike1D,
Expand Down Expand Up @@ -337,7 +338,12 @@ def _autolag_ols(
max_lags=maxlag, lag=max(exog_rank - startlag, 0)
)
)
q, r = qr(exog)

endog = anp.asarray(endog)
exog = anp.asarray(exog)
q, r = alinalg.qr(exog)
# Convert it to 2-d so as to adapt to linalg.solve input format for all
# engines
qpy = q.T @ endog
ypy = endog.T @ endog
xpx = exog.T @ exog
Expand All @@ -347,12 +353,12 @@ def _autolag_ols(
nobs = float(endog.shape[0])
tstat[0] = inf
for i in range(startlag, startlag + maxlag + 1):
b = solve(r[:i, :i], qpy[:i])
sigma2[i - startlag] = squeeze(ypy - b.T @ xpx[:i, :i] @ b) / nobs
b = alinalg.solve(r[:i, :i], qpy[:i])
sigma2[i - startlag] = anp.squeeze(ypy - b.T @ xpx[:i, :i] @ b) / nobs
if lower_method == "t-stat" and i > startlag:
xpxi = inv(xpx[:i, :i])
stderr = sqrt(sigma2[i - startlag] * xpxi[-1, -1])
tstat[i - startlag] = squeeze(b[-1]) / stderr
xpxi = alinalg.inv(xpx[:i, :i])
stderr = anp.sqrt(sigma2[i - startlag] * xpxi[-1, -1])
tstat[i - startlag] = anp.squeeze(b[-1]) / stderr

return _select_best_ic(method, nobs, sigma2, tstat)

Expand Down
Loading