Skip to content

Commit

Permalink
LOFOImportance accepts groups params for GroupKFold cv (#57)
Browse files Browse the repository at this point in the history
* Add groups option

* Add groups option

* Update lofo_importance.py
  • Loading branch information
kingychiu committed Jan 16, 2024
1 parent aa0d224 commit 67d787a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lofo/lofo_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ class LOFOImportance:
Same as cv in sklearn API
n_jobs: int, optional
Number of jobs for parallel computation
cv_groups: array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into train/test set.
Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).
"""

def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None):
def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=None, cv_groups=None):

self.fit_params = fit_params if fit_params else dict()
if model is None:
Expand All @@ -38,6 +41,7 @@ def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=N
self.dataset = dataset
self.scoring = scoring
self.cv = cv
self.cv_groups = cv_groups
self.n_jobs = n_jobs
if self.n_jobs is not None and self.n_jobs > 1:
warning_str = ("Warning: If your model is multithreaded, please initialise the number"
Expand All @@ -50,7 +54,7 @@ def _get_cv_score(self, feature_to_remove):

with warnings.catch_warnings():
warnings.simplefilter("ignore")
cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params)
cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params, groups=self.cv_groups)
return cv_results['test_score']

def _get_cv_score_parallel(self, feature, result_queue):
Expand Down

0 comments on commit 67d787a

Please sign in to comment.