
Enabled feature_importances_ for our ForestDML and ForestDRLearner estimators #306

Merged 17 commits from vasilis/feature_importances into master on Nov 9, 2020

Conversation

vsyrgkanis
Collaborator

@vsyrgkanis vsyrgkanis commented Nov 7, 2020

This required changing the subsampled honest forest code a bit so that it does not alter the arrays of the sklearn tree structures, but rather stores two additional arrays required for prediction. The extra memory allocation adds around 1.5 times the original running time, so it makes things slightly slower.
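The storage scheme described above can be sketched roughly as follows. This is a hypothetical illustration, not the actual econml implementation: the function name `honest_forest_predict` and the `numerators`/`denominators` arrays are invented here to show how prediction can be computed from two extra per-node arrays without touching sklearn's tree internals.

```python
# Hypothetical sketch of the scheme described above: each tree keeps two
# auxiliary arrays indexed by node id, and the forest prediction is the
# ratio of their sums across trees (a weighted average, not a plain mean
# of per-tree predictions).
import numpy as np

def honest_forest_predict(trees, numerators, denominators, X):
    """Predict as sum of per-node numerators over sum of denominators.

    trees        : list of fitted sklearn tree estimators
    numerators   : numerators[i][node] = sum of estimation-sample targets
                   routed to `node` of tree i
    denominators : denominators[i][node] = count of estimation samples
                   routed to `node` of tree i
    """
    num = np.zeros(X.shape[0])
    den = np.zeros(X.shape[0])
    for tree, n_arr, d_arr in zip(trees, numerators, denominators):
        leaves = tree.apply(X)  # leaf node id for each row of X
        num += n_arr[leaves]
        den += d_arr[leaves]
    return num / den
```

With a single tree whose auxiliary arrays are built from the same sample the tree was fit on, this reduces to the tree's own leaf means; honesty comes from filling the arrays with a held-out estimation sample instead.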

However, this enables correct feature_importance calculation and, in the future, correct SHAP calculation (fixes #297), since the tree entries are now consistent with a tree in a RandomForestRegressor, so the SHAP logic can be applied if we recast the subsampled honest forest as a RandomForestRegressor. Additivity of SHAP will still be violated, since the prediction of the subsampled honest forest is not just the aggregation of the predictions across the trees but a more complex weighted average; even so, we can still call SHAP and get meaningful SHAP numbers. One discrepancy is that SHAP explains a different value than what effect returns: it explains the value corresponding to the average of the predictions of each honest tree regressor, whereas the prediction of an honest forest is not the average of the tree predictions. A full solution to this small discrepancy would require reworking SHAP's tree explainer algorithm to account for such alternative aggregations of tree predictors.

This enables the following two uses:

from econml.dml import ForestDMLCateEstimator
from sklearn.ensemble import RandomForestRegressor
import sklearn.ensemble
import shap
import copy
est3 = ForestDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor(),
                              n_estimators=1000,
                              subsample_fr=.8,
                              min_samples_leaf=10,
                              min_impurity_decrease=0.001,
                              verbose=0, min_weight_fraction_leaf=.01)
est3.fit(Y, T, X, W)  # Y, T, X, W: user-provided outcome, treatment, features, controls
print(est3.feature_importances_)
# Recast the fitted CATE model as a RandomForestRegressor so that
# SHAP's tree explainer logic can be applied to it
model = copy.deepcopy(est3.model_cate)
model.__class__ = sklearn.ensemble.RandomForestRegressor
explainer = shap.Explainer(model, X)
shap_values = explainer(X)
shap.plots.beeswarm(shap_values, X)
from econml.drlearner import ForestDRLearner
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
import numpy as np
import sklearn.ensemble
import shap
import copy
est3 = ForestDRLearner(model_regression=RandomForestRegressor(),
                       model_propensity=RandomForestClassifier(min_samples_leaf=10),
                       min_propensity=1e-3,
                       n_estimators=1000,
                       subsample_fr=.8,
                       min_samples_leaf=10,
                       min_impurity_decrease=0.001,
                       verbose=0, min_weight_fraction_leaf=.01)
est3.fit(Y, T, X, W)  # Y, T, X, W: user-provided outcome, treatment, features, controls
for t in np.unique(T):
    if t > 0:
        print(est3.feature_importances_(T=t))
        # Recast each treatment's CATE model so SHAP's tree explainer applies
        model = copy.deepcopy(est3.model_cate(T=t))
        model.__class__ = sklearn.ensemble.RandomForestRegressor
        explainer = shap.Explainer(model, X)
        shap_values = explainer(X)
        shap.plots.beeswarm(shap_values, X)

Side benefits:

  1. 6x speed-up of SubsampledHonestForest, achieved by converting dense but sparsely represented matrices to a dense representation before a loop that performs many slicing operations on them.
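The idea behind this speed-up can be illustrated with a small sketch. This is not the actual econml code; the function names `row_sums_slow` and `row_sums_fast` are invented here to contrast repeated slicing of a scipy sparse matrix inside a loop with a single up-front conversion to a dense ndarray.

```python
# Illustrative sketch of the optimization: slicing a CSR matrix row-by-row
# inside a loop is much slower than densifying once before the loop, when
# the matrix is dense in content but sparsely represented.
import numpy as np
from scipy import sparse

def row_sums_slow(A_sparse, rows):
    # Many slicing operations directly on the sparse matrix: slow
    return [A_sparse[r].sum() for r in rows]

def row_sums_fast(A_sparse, rows):
    A = np.asarray(A_sparse.todense())  # one conversion before the loop
    return [A[r].sum() for r in rows]
```

Both functions return the same values; the second trades a one-time memory allocation for much cheaper per-iteration indexing.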

… but rather create auxiliary numpy arrays that store the numerator and denominator of every node. This enables consistent feature_importance calculation and also potentially more accurate shap_values calculation.
@vsyrgkanis vsyrgkanis added the enhancement New feature or request label Nov 7, 2020
…ure_importances_. Added tests that the feature_importances_ API is working in test_drlearner and test_dml.
vsyrgkanis and others added 5 commits November 8, 2020 00:58
Co-authored-by: Keith Battocchi <kebatt@microsoft.com>
…ree level was causing trouble, since due to honesty and sample splitting the feature_importance can often be negative (an increase in variance). Now averaging the un-normalized feature importance. There is still a small caveat in how the current version uses impurity. Added that as a TODO.
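The aggregation change described in this commit can be sketched as follows. This is a hedged illustration, not the econml implementation: the function name `forest_importances` is invented here to show averaging un-normalized per-tree importances and normalizing once at the forest level, which stays well-behaved even when some per-tree impurity decreases are negative.

```python
# Sketch of forest-level normalization of impurity-based importances:
# average the raw (un-normalized) per-tree importances, then normalize
# the average once, instead of normalizing each tree separately.
import numpy as np

def forest_importances(per_tree_importances):
    """per_tree_importances: array of shape (n_trees, n_features),
    un-normalized total impurity decrease per feature in each tree
    (entries may be negative under honest sample splitting)."""
    avg = np.mean(per_tree_importances, axis=0)
    total = avg.sum()
    return avg / total if total > 0 else avg
```

Normalizing per tree fails when a tree's total importance is near zero or negative; deferring normalization to the forest level avoids dividing by those unstable per-tree totals.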
…orest, which now makes feature_importances_ exactly correct, with no need to re-implement the method. Impurities are now computed on the estimation sample, replacing the pre-calculated node impurities.
…rallel_add_trees_ of ensemble.py. This leads to a 6-fold speed-up, since we were previously doing many slicing operations on sparse matrices, which are very slow!
Collaborator

@kbattocchi kbattocchi left a comment

Mostly looks fine, though I'm not familiar enough with the tree internals to vouch for correctness. Please consider addressing the comments I left before merging.

@vsyrgkanis vsyrgkanis merged commit 61cd136 into master Nov 9, 2020
@vsyrgkanis vsyrgkanis deleted the vasilis/feature_importances branch November 16, 2020 22:54
Labels
enhancement New feature or request
Successfully merging this pull request may close these issues.

Shap interpretability with SubsampledHonestForest
3 participants