Add Sensitivity Analysis #201

Merged 55 commits on Jun 21, 2023
Commits
85490aa
add basic sensitivity element structure
SvenKlaassen May 10, 2023
a42993d
add sensitivity elements to IRM
SvenKlaassen May 10, 2023
d1c8da6
add _est_sensitivity_elements method to all classes
SvenKlaassen May 11, 2023
5512ed2
add weights to sensitivity elements for IRM
SvenKlaassen May 11, 2023
2320bb7
add basic ci to sensitivity analysis
SvenKlaassen May 11, 2023
84ea4b1
add make_dataset for bias bounds
SvenKlaassen May 12, 2023
338b855
update make_conf_irm_data
SvenKlaassen May 12, 2023
1ced4dd
add sensitivity elements to PLR
SvenKlaassen May 12, 2023
44c5a3c
fix sensitivity elements PLR and format
SvenKlaassen May 12, 2023
ba7ff3e
update NotImplementedErrors for sensitivity analysis
SvenKlaassen May 12, 2023
da5a60a
add contour plot
SvenKlaassen May 12, 2023
dfd08ed
add input checks to sensitivity analysis
SvenKlaassen May 12, 2023
90c5067
restructure sensitivity analysis
SvenKlaassen May 15, 2023
2e18747
force theta to be a float
SvenKlaassen May 15, 2023
0b87e4f
add sensitivity summary
SvenKlaassen May 15, 2023
3f42c9f
rename make_confounded_irm_data
SvenKlaassen May 15, 2023
fb6e4dc
adapt for callable scores
SvenKlaassen May 15, 2023
32bf981
fix unit tests
SvenKlaassen May 15, 2023
02336c0
adapt sensitivity to external sample splitting
SvenKlaassen May 15, 2023
230e37f
fix contour plot hover_template
SvenKlaassen May 16, 2023
e388fdd
add effect heterogeneity
SvenKlaassen May 16, 2023
f258635
add input tests for sensitivity analysis
SvenKlaassen May 16, 2023
829503d
update unit tests
SvenKlaassen May 16, 2023
a8aed64
extend sensitivity unit tests for model defaults and return types
SvenKlaassen May 17, 2023
f7eeef5
add documentation to sensitivity analysis
SvenKlaassen May 17, 2023
8ad2a03
update dgps
SvenKlaassen May 17, 2023
5f9e84f
fix confounded plr dgp
SvenKlaassen May 17, 2023
b1c975d
Update datasets.py
SvenKlaassen May 22, 2023
69928dc
add unit tests for sensitivity analysis
SvenKlaassen May 22, 2023
9dc2a1b
change var estimation to _utils
SvenKlaassen May 22, 2023
0277c1b
update variance estimation (add clustering for sensitivity analysis)
SvenKlaassen May 23, 2023
e891ee7
fix format
SvenKlaassen May 23, 2023
fce877d
add basic unit test for clustering and sensitivity analysis
SvenKlaassen May 23, 2023
4613a69
fix plr sensitivity unit test
SvenKlaassen May 23, 2023
0adb028
extend unit tests
SvenKlaassen May 24, 2023
c5b35f7
restructure checks
SvenKlaassen May 24, 2023
968443f
adjust rr for plr
SvenKlaassen May 24, 2023
1b0cdbf
add g_hat1 to IRM and DID for sensitivity analysis
SvenKlaassen May 26, 2023
3f24947
add exceptions for neg var
SvenKlaassen May 26, 2023
08afa60
update PLR and IRM sensitivity_elements
SvenKlaassen May 26, 2023
41f6b53
did sensitivity bounds
SvenKlaassen May 26, 2023
3fb30af
update unit tests
SvenKlaassen May 27, 2023
f44f8ad
add tests for se if rho=0
SvenKlaassen May 30, 2023
1e1d212
remove psi_scaled from sensitivity_elements
SvenKlaassen May 30, 2023
c5c758b
add sensitivity elements for DID and DIDCS
SvenKlaassen May 30, 2023
63a5b2a
add unit tests for sensitivity_elements DID and DIDCS
SvenKlaassen May 30, 2023
5d4613c
update docstrings for confounded datasets
SvenKlaassen May 31, 2023
d6c0558
add the possibility to remove the scenario from the plot
SvenKlaassen May 31, 2023
197597c
fix unit tests
SvenKlaassen May 31, 2023
1e58633
fix typo
SvenKlaassen Jun 5, 2023
a63baaf
fix bug in sensitivity summary
SvenKlaassen Jun 9, 2023
59945ec
Update api.rst
SvenKlaassen Jun 9, 2023
c797881
adjust for multiple h_0 and rename to null_hypothesis
SvenKlaassen Jun 9, 2023
ebdc682
fix format
SvenKlaassen Jun 9, 2023
ad636bd
Merge branch 'main' into s-bias-bounds
SvenKlaassen Jun 15, 2023
doc/api/api.rst (3 additions, 0 deletions)

@@ -57,6 +57,9 @@ Dataset generators
datasets.make_iivm_data
datasets.make_plr_turrell2018
datasets.make_pliv_multiway_cluster_CKMS2021
datasets.make_confounded_plr_data
datasets.make_confounded_irm_data


Score mixin classes for double machine learning models
------------------------------------------------------
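
The two generators registered above are the new confounded DGPs used throughout the sensitivity analysis. A minimal usage sketch follows; the n_obs/theta keyword names and the dict-style return value are assumptions by analogy with the package's other dataset helpers, not confirmed by this diff — consult the docstrings added in this PR for the actual interface.

import numpy as np
from doubleml.datasets import make_confounded_irm_data, make_confounded_plr_data

# Hypothetical keyword arguments and return structure.
np.random.seed(42)
dgp_irm = make_confounded_irm_data(n_obs=1000, theta=5.0)
dgp_plr = make_confounded_plr_data(n_obs=1000, theta=5.0)

# Assumed dict layout: covariates 'x', outcome 'y', treatment 'd'.
x, y, d = dgp_irm['x'], dgp_irm['y'], dgp_irm['d']
print(x.shape, y.shape, d.shape)
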
doubleml/_utils.py (89 additions, 144 deletions)

@@ -7,12 +7,13 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold, GridSearchCV, RandomizedSearchCV
from sklearn.metrics import mean_squared_error
from sklearn.utils.multiclass import type_of_target

from statsmodels.nonparametric.kde import KDEUnivariate

from joblib import Parallel, delayed

from ._utils_checks import _check_is_partition


def _assure_2d_array(x):
if x.ndim == 1:
@@ -40,63 +41,6 @@ def _get_cond_smpls_2d(smpls, bin_var1, bin_var2):
return smpls_00, smpls_01, smpls_10, smpls_11


def _check_is_partition(smpls, n_obs):
test_indices = np.concatenate([test_index for _, test_index in smpls])
if len(test_indices) != n_obs:
return False
hit = np.zeros(n_obs, dtype=bool)
hit[test_indices] = True
if not np.all(hit):
return False
return True


def _check_all_smpls(all_smpls, n_obs, check_intersect=False):
all_smpls_checked = list()
for smpl in all_smpls:
all_smpls_checked.append(_check_smpl_split(smpl, n_obs, check_intersect))
return all_smpls_checked


def _check_smpl_split(smpl, n_obs, check_intersect=False):
smpl_checked = list()
for tpl in smpl:
smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs, check_intersect))
return smpl_checked


def _check_smpl_split_tpl(tpl, n_obs, check_intersect=False):
train_index = np.sort(np.array(tpl[0]))
test_index = np.sort(np.array(tpl[1]))

if not issubclass(train_index.dtype.type, np.integer):
raise TypeError('Invalid sample split. Train indices must be of type integer.')
if not issubclass(test_index.dtype.type, np.integer):
raise TypeError('Invalid sample split. Test indices must be of type integer.')

if check_intersect:
if set(train_index) & set(test_index):
raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')

if len(np.unique(train_index)) != len(train_index):
raise ValueError('Invalid sample split. Train indices contain non-unique entries.')
if len(np.unique(test_index)) != len(test_index):
raise ValueError('Invalid sample split. Test indices contain non-unique entries.')

# we sort the indices above
# if not np.all(np.diff(train_index) > 0):
# raise NotImplementedError('Invalid sample split. Only sorted train indices are supported.')
# if not np.all(np.diff(test_index) > 0):
# raise NotImplementedError('Invalid sample split. Only sorted test indices are supported.')

if not set(train_index).issubset(range(n_obs)):
raise ValueError('Invalid sample split. Train indices must be in [0, n_obs).')
if not set(test_index).issubset(range(n_obs)):
raise ValueError('Invalid sample split. Test indices must be in [0, n_obs).')

return train_index, test_index
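
For reference, the sample-split format that these validators (now relocated to _utils_checks, per the new import at the top of the file) enforce is a list of (train_index, test_index) tuples whose test indices together partition range(n_obs). A minimal sketch of a split that passes the checks:

import numpy as np
from sklearn.model_selection import KFold

# Each element of smpls is a (train_index, test_index) tuple of integer
# arrays; across folds the test indices cover every observation exactly
# once, which is what _check_is_partition verifies.
n_obs = 10
kf = KFold(n_splits=2, shuffle=True, random_state=0)
smpls = [(train, test) for train, test in kf.split(np.arange(n_obs))]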


def _fit(estimator, x, y, train_index, idx=None):
estimator.fit(x[train_index, :], y[train_index])
return estimator, idx
@@ -238,13 +182,6 @@ def _draw_weights(method, n_rep_boot, n_obs):
return weights


def _check_finite_predictions(preds, learner, learner_name, smpls):
test_indices = np.concatenate([test_index for _, test_index in smpls])
if not np.all(np.isfinite(preds[test_indices])):
raise ValueError(f'Predictions from learner {str(learner)} for {learner_name} are not finite.')
return


def _trimm(preds, trimming_rule, trimming_threshold):
if trimming_rule == 'truncate':
preds[preds < trimming_threshold] = trimming_threshold
@@ -261,14 +198,6 @@ def _normalize_ipw(propensity, treatment):
return normalized_weights


def _check_is_propensity(preds, learner, learner_name, smpls, eps=1e-12):
test_indices = np.concatenate([test_index for _, test_index in smpls])
if any((preds[test_indices] < eps) | (preds[test_indices] > 1 - eps)):
warnings.warn(f'Propensity predictions from learner {str(learner)} for'
f' {learner_name} are close to zero or one (eps={eps}).')
return


def _rmse(y_true, y_pred):
subset = np.logical_not(np.isnan(y_true))
rmse = mean_squared_error(y_true[subset], y_pred[subset], squared=False)
@@ -285,77 +214,6 @@ def _predict_zero_one_propensity(learner, X):
return res


def _check_contains_iv(obj_dml_data):
if obj_dml_data.z_cols is not None:
raise ValueError('Incompatible data. ' +
' and '.join(obj_dml_data.z_cols) +
' have been set as instrumental variable(s). '
'To fit a local model, see the documentation.')


def _check_zero_one_treatment(obj_dml):
one_treat = (obj_dml._dml_data.n_treat == 1)
binary_treat = (type_of_target(obj_dml._dml_data.d) == 'binary')
zero_one_treat = np.all((np.power(obj_dml._dml_data.d, 2) - obj_dml._dml_data.d) == 0)
if not (one_treat & binary_treat & zero_one_treat):
raise ValueError('Incompatible data. '
f'To fit an {str(obj_dml.score)} model with DML '
'exactly one binary variable with values 0 and 1 '
'needs to be specified as treatment variable.')


def _check_quantile(quantile):
if not isinstance(quantile, float):
raise TypeError('Quantile has to be a float. ' +
f'Object of type {str(type(quantile))} passed.')

if (quantile <= 0) | (quantile >= 1):
raise ValueError('Quantile has be between 0 or 1. ' +
f'Quantile {str(quantile)} passed.')
return


def _check_treatment(treatment):
if not isinstance(treatment, int):
raise TypeError('Treatment indicator has to be an integer. ' +
f'Object of type {str(type(treatment))} passed.')

if (treatment != 0) & (treatment != 1):
raise ValueError('Treatment indicator has to be either 0 or 1. ' +
f'Treatment indicator {str(treatment)} passed.')
return


def _check_trimming(trimming_rule, trimming_threshold):
valid_trimming_rule = ['truncate']
if trimming_rule not in valid_trimming_rule:
raise ValueError('Invalid trimming_rule ' + str(trimming_rule) + '. ' +
'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')
if not isinstance(trimming_threshold, float):
raise TypeError('trimming_threshold has to be a float. ' +
f'Object of type {str(type(trimming_threshold))} passed.')
if (trimming_threshold <= 0) | (trimming_threshold >= 0.5):
raise ValueError('Invalid trimming_threshold ' + str(trimming_threshold) + '. ' +
'trimming_threshold has to be between 0 and 0.5.')
return


def _check_score(score, valid_score, allow_callable=True):
if isinstance(score, str):
if score not in valid_score:
raise ValueError('Invalid score ' + score + '. ' +
'Valid score ' + ' or '.join(valid_score) + '.')
else:
if allow_callable:
if not callable(score):
raise TypeError('score should be either a string or a callable. '
'%r was passed.' % score)
else:
raise TypeError('score should be a string. '
'%r was passed.' % score)
return


def _get_bracket_guess(score, coef_start, coef_bounds):
max_bracket_length = coef_bounds[1] - coef_bounds[0]
b_guess = coef_bounds
@@ -388,3 +246,90 @@ def abs_ipw_score(theta):
method='brent')
ipw_est = res.x
return ipw_est


def _aggregate_coefs_and_ses(all_coefs, all_ses, var_scaling_factor):
# aggregation is done over dimension 1, such that the coefs and ses have to be of shape (n_coefs, n_rep)
n_rep = all_coefs.shape[1]
coefs = np.median(all_coefs, 1)

xx = np.tile(coefs.reshape(-1, 1), n_rep)
ses = np.sqrt(np.divide(np.median(np.multiply(np.power(all_ses, 2), var_scaling_factor) +
np.power(all_coefs - xx, 2), 1), var_scaling_factor))

return coefs, ses
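
A small numeric sketch (illustrative values only) of the aggregation rule above: the point estimate is the per-coefficient median across repetitions, and the reported standard error folds the spread of the repetition estimates around that median into the median variance:

import numpy as np

# Two coefficients estimated over three repetitions (made-up numbers).
all_coefs = np.array([[0.9, 1.1, 1.0],
                      [2.0, 2.2, 1.8]])
all_ses = np.array([[0.10, 0.12, 0.11],
                    [0.20, 0.18, 0.22]])
var_scaling_factor = 100.0  # e.g. the number of observations

coefs = np.median(all_coefs, axis=1)
# median over repetitions of (scaled variance + squared deviation from the
# median coefficient), rescaled back to a standard error
dev2 = np.power(all_coefs - coefs.reshape(-1, 1), 2)
ses = np.sqrt(np.median(np.power(all_ses, 2) * var_scaling_factor + dev2, axis=1)
              / var_scaling_factor)
print(coefs, ses)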


def _var_est(psi, psi_deriv, apply_cross_fitting, smpls, is_cluster_data,
cluster_vars=None, smpls_cluster=None, n_folds_per_cluster=None):

if not is_cluster_data:
# psi and psi_deriv should be of shape (n_obs, ...)
if apply_cross_fitting:
var_scaling_factor = psi.shape[0]
else:
# In case of no-cross-fitting, the score function was only evaluated on the test data set
test_index = smpls[0][1]
psi_deriv = psi_deriv[test_index]
psi = psi[test_index]
var_scaling_factor = len(test_index)

J = np.mean(psi_deriv)
gamma_hat = np.mean(np.square(psi))

else:
assert cluster_vars is not None
assert smpls_cluster is not None
assert n_folds_per_cluster is not None
n_folds = len(smpls)

# one cluster variable
if cluster_vars.shape[1] == 1:
first_cluster_var = cluster_vars[:, 0]
clusters = np.unique(first_cluster_var)
gamma_hat = 0
j_hat = 0
for i_fold in range(n_folds):
test_inds = smpls[i_fold][1]
test_cluster_inds = smpls_cluster[i_fold][1]
I_k = test_cluster_inds[0]
const = 1 / len(I_k)
for cluster_value in I_k:
ind_cluster = (first_cluster_var == cluster_value)
gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
j_hat += np.sum(psi_deriv[test_inds]) / len(I_k)

var_scaling_factor = len(clusters)
J = np.divide(j_hat, n_folds_per_cluster)
gamma_hat = np.divide(gamma_hat, n_folds_per_cluster)

else:
assert cluster_vars.shape[1] == 2
first_cluster_var = cluster_vars[:, 0]
second_cluster_var = cluster_vars[:, 1]
gamma_hat = 0
j_hat = 0
for i_fold in range(n_folds):
test_inds = smpls[i_fold][1]
test_cluster_inds = smpls_cluster[i_fold][1]
I_k = test_cluster_inds[0]
J_l = test_cluster_inds[1]
const = np.divide(min(len(I_k), len(J_l)), (np.square(len(I_k) * len(J_l))))
for cluster_value in I_k:
ind_cluster = (first_cluster_var == cluster_value) & np.in1d(second_cluster_var, J_l)
gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
for cluster_value in J_l:
ind_cluster = (second_cluster_var == cluster_value) & np.in1d(first_cluster_var, I_k)
gamma_hat += const * np.sum(np.outer(psi[ind_cluster], psi[ind_cluster]))
j_hat += np.sum(psi_deriv[test_inds]) / (len(I_k) * len(J_l))

n_first_clusters = len(np.unique(first_cluster_var))
n_second_clusters = len(np.unique(second_cluster_var))
var_scaling_factor = min(n_first_clusters, n_second_clusters)
J = np.divide(j_hat, np.square(n_folds_per_cluster))
gamma_hat = np.divide(gamma_hat, np.square(n_folds_per_cluster))

scaling = np.divide(1.0, np.multiply(var_scaling_factor, np.square(J)))
sigma2_hat = np.multiply(scaling, gamma_hat)

return sigma2_hat, var_scaling_factor
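
A minimal sketch of the non-clustered, cross-fitted branch above with hypothetical score values: with J the mean score derivative and gamma_hat the mean squared score, the variance estimate is gamma_hat / (n * J**2):

import numpy as np

rng = np.random.default_rng(0)
n_obs = 500
psi = rng.normal(size=n_obs)   # hypothetical score evaluations
psi_deriv = -np.ones(n_obs)    # hypothetical score derivatives

J = np.mean(psi_deriv)
gamma_hat = np.mean(np.square(psi))
var_scaling_factor = n_obs
sigma2_hat = gamma_hat / (var_scaling_factor * np.square(J))  # mirrors the code above
se = np.sqrt(sigma2_hat)
print(sigma2_hat, se)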