Skip to content

Commit

Permalink
add more comprehensive tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMBury committed Aug 4, 2023
1 parent 6764ee9 commit 60a530a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
54 changes: 53 additions & 1 deletion ewstools/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(self, data, transition=None):
print(
"Make sure to provide data as either a list, np.ndarray or pd.Series\n"
)
return

# Set state and transition attributes
self.state = df_state
Expand Down Expand Up @@ -1129,6 +1130,7 @@ def __init__(self, data, transition=None):
if type(data) != pd.DataFrame:
print("\nERROR: data has been provided as type {}".format(type(data)))
print("Please provide data as a pandas DataFrame.\n")
return

# Set state and transition attributes
self.state = data
Expand Down Expand Up @@ -1224,7 +1226,8 @@ def compute_covar(self, rolling_window=0.25, leading_eval=False):
If residuals have not been computed, computation will be
performed over state variable.
Put into 'ews' dataframe
Put covariance matrices into self.covar
Put leading eigenvalue into self.ews
Parameters
----------
Expand Down Expand Up @@ -1282,6 +1285,55 @@ def compute_covar(self, rolling_window=0.25, leading_eval=False):
series_evals = pd.Series(ar_evals, index=df_pre.index)
self.ews["covar_leading_eval"] = series_evals

def compute_corr(self, rolling_window=0.25):
"""
Compute the (Pearson) correlation matrix over a rolling window.
If residuals have not been computed, computation will be
performed over state variable.
Put correlation matrices into self.corr
Parameters
----------
rolling_window : float
Length of rolling window used to compute variance. Can be specified
as an absolute value or as a proportion of the length of the
data being analysed. Default is 0.25.
Returns
-------
None.
"""

# Get time series data prior to transition
if self.transition:
df_pre = self.state[self.state.index <= self.transition]
else:
df_pre = self.state

# Get absolute size of rollling window if given as a proportion
if 0 < rolling_window <= 1:
rw_absolute = int(rolling_window * len(df_pre))
else:
rw_absolute = rolling_window

# If residuals column exists, compute over residuals.
if "{}_residuals".format(self.var_names[0]) in df_pre.columns:
col_names_to_compute = [
"{}_residuals".format(var) for var in self.var_names
]
else:
col_names_to_compute = self.var_names

# Compute correlation matrix
df_corr = (
df_pre[col_names_to_compute]
.rolling(window=rw_absolute)
.corr(method="Pearson")
)
self.corr = df_corr


# -----------------------------
# Eigenvalue reconstruction
Expand Down
13 changes: 12 additions & 1 deletion tests/test_ewstools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,22 @@ def test_MultiTimeSeries_ews():
df.index.name = "index_name"

# Create MultiTimeSeries object
mts = ewstools.core.MultiTimeSeries(df, transition=8)
mts = ewstools.MultiTimeSeries(np.array([1, 2, 3])) # invalid entry
mts = ewstools.MultiTimeSeries(df, transition=8)
mts.detrend(method="XXX", bandwidth=0.2) # invalid detrend method
mts.detrend(method="Gaussian", bandwidth=0.2)
mts.detrend(method="Gaussian", bandwidth=20)
mts.detrend(method="Lowess", span=0.2)
mts.detrend(method="Lowess", span=20)

mts.compute_covar(rolling_window=0.25, leading_eval=True)
mts.compute_covar(rolling_window=20, leading_eval=False)
mts.compute_corr(rolling_window=0.25)
mts.compute_corr(rolling_window=20)

assert type(mts.ews) == pd.DataFrame
assert type(mts.covar) == pd.DataFrame
assert type(mts.corr) == pd.DataFrame
assert "x_residuals" in mts.state.columns
assert "z_smoothing" in mts.state.columns
assert "covar_leading_eval" in mts.ews.columns
Expand All @@ -81,6 +91,7 @@ def test_TimeSeries_init():
xVals = 5 + np.random.normal(0, 1, len(tVals))

# Create TimeSeries object using np.ndarray
ts = ewstools.TimeSeries("hello") # invalid entry
ts = ewstools.TimeSeries(xVals)
assert type(ts.state) == pd.DataFrame
assert type(ts.ews) == pd.DataFrame
Expand Down

0 comments on commit 60a530a

Please sign in to comment.