Skip to content

Commit

Permalink
update for new sklearn version
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhao062@gmail.com authored and yzhao062@gmail.com committed Jul 16, 2023
1 parent 3235ab0 commit 2130bba
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions suod/test/test_model_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pyod.models.lscp import LSCP
from joblib import dump, load

from ..utils.utility import _get_sklearn_version


class TestModelSaveLoad(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -47,11 +49,19 @@ def setUp(self):

this_directory = os.path.abspath(os.path.dirname(__file__))

self.cost_forecast_loc_fit_ = os.path.join(this_directory,
'bps_train.joblib')
sklearn_version = _get_sklearn_version()
if sklearn_version[:3] >= '1.3':
self.cost_forecast_loc_fit_ = os.path.join(this_directory,
'bps_train.joblib')

self.cost_forecast_loc_pred_ = os.path.join(this_directory,
'bps_prediction.joblib')
else:
self.cost_forecast_loc_fit_ = os.path.join(this_directory,
'bps_train_old.joblib')

self.cost_forecast_loc_pred_ = os.path.join(this_directory,
'bps_prediction.joblib')
self.cost_forecast_loc_pred_ = os.path.join(this_directory,
'bps_prediction_old.joblib')

self.model = SUOD(base_estimators=self.base_estimators, n_jobs=2,
rp_flag_global=True, bps_flag=True,
Expand Down Expand Up @@ -81,7 +91,8 @@ def test_load(self):
model = load('model.joblib')

predicted_labels = model.predict(self.X_test) # predict labels
predicted_scores = model.decision_function(self.X_test) # predict scores
predicted_scores = model.decision_function(
self.X_test) # predict scores
predicted_probs = model.predict_proba(self.X_test) # predict scores

assert (len(predicted_labels) != 0)
Expand Down

0 comments on commit 2130bba

Please sign in to comment.