From 3b567a378592209d7cdd26415761827a06338773 Mon Sep 17 00:00:00 2001 From: Stefano Ariestasia Date: Sat, 17 Aug 2024 23:33:43 +0900 Subject: [PATCH] add EMA as mode to SSLChannels --- technical/indicators/indicators.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/technical/indicators/indicators.py b/technical/indicators/indicators.py index 54c3797d..0c2321c8 100644 --- a/technical/indicators/indicators.py +++ b/technical/indicators/indicators.py @@ -1094,22 +1094,27 @@ def SSLChannels(dataframe, length=10, mode="sma"): Usage: dataframe['sslDown'], dataframe['sslUp'] = SSLChannels(dataframe, 10) """ - if mode not in ("sma"): + import talib.abstract as ta + + mode_lower = mode.lower() + + if mode_lower not in ("sma", "ema"): raise ValueError(f"Mode {mode} not supported yet") df = dataframe.copy() - if mode == "sma": - df["smaHigh"] = df["high"].rolling(length).mean() - df["smaLow"] = df["low"].rolling(length).mean() + if mode_lower == "sma": + ma_high = df["high"].rolling(length).mean() + ma_low = df["low"].rolling(length).mean() + elif mode_lower == "ema": + ma_high = ta.EMA(df["high"], length) + ma_low = ta.EMA(df["low"], length) - df["hlv"] = np.where( - df["close"] > df["smaHigh"], 1, np.where(df["close"] < df["smaLow"], -1, np.NAN) - ) + df["hlv"] = np.where(df["close"] > ma_high, 1, np.where(df["close"] < ma_low, -1, np.NAN)) df["hlv"] = df["hlv"].ffill() - df["sslDown"] = np.where(df["hlv"] < 0, df["smaHigh"], df["smaLow"]) - df["sslUp"] = np.where(df["hlv"] < 0, df["smaLow"], df["smaHigh"]) + df["sslDown"] = np.where(df["hlv"] < 0, ma_high, ma_low) + df["sslUp"] = np.where(df["hlv"] < 0, ma_low, ma_high) return df["sslDown"], df["sslUp"]