Skip to content

Commit

Permalink
Merge pull request #249 from DoubleML/s-add-sensitivity-framework
Browse files Browse the repository at this point in the history
Add sensitivity analysis to framework class
  • Loading branch information
SvenKlaassen authored Aug 12, 2024
2 parents 5f19bd8 + 399decf commit 20d9864
Show file tree
Hide file tree
Showing 15 changed files with 1,139 additions and 361 deletions.
4 changes: 3 additions & 1 deletion doubleml/did/did.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
4 changes: 3 additions & 1 deletion doubleml/did/did_cs.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
330 changes: 86 additions & 244 deletions doubleml/double_ml.py

Large diffs are not rendered by default.

574 changes: 533 additions & 41 deletions doubleml/double_ml_framework.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion doubleml/irm/irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ def _sensitivity_element_est(self, preds):
element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
10 changes: 7 additions & 3 deletions doubleml/plm/plr.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,17 @@ def _sensitivity_element_est(self, preds):
sigma2 = np.mean(sigma2_score_element)
psi_sigma2 = sigma2_score_element - sigma2

nu2 = np.divide(1.0, np.mean(np.square(d - m_hat)))
psi_nu2 = nu2 - np.multiply(np.square(d-m_hat), np.square(nu2))
treatment_residual = d - m_hat
nu2 = np.divide(1.0, np.mean(np.square(treatment_residual)))
psi_nu2 = nu2 - np.multiply(np.square(treatment_residual), np.square(nu2))
rr = np.multiply(treatment_residual, nu2)

element_dict = {'sigma2': sigma2,
'nu2': nu2,
'psi_sigma2': psi_sigma2,
'psi_nu2': psi_nu2}
'psi_nu2': psi_nu2,
'riesz_rep': rr,
}
return element_dict

def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
Expand Down
8 changes: 6 additions & 2 deletions doubleml/tests/_utils_doubleml_sensitivity_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def doubleml_sensitivity_manual(sensitivity_elements, all_coefs, psi, psi_deriv,
theta_upper, sigma_upper = _aggregate_coefs_and_ses(all_theta_upper, all_sigma_upper, var_scaling_factor)

quant = norm.ppf(level)
ci_lower = theta_lower - np.multiply(quant, sigma_lower)
ci_upper = theta_upper + np.multiply(quant, sigma_upper)

all_ci_lower = all_theta_lower - np.multiply(quant, all_sigma_lower)
all_ci_upper = all_theta_upper + np.multiply(quant, all_sigma_upper)

ci_lower = np.median(all_ci_lower, axis=1)
ci_upper = np.median(all_ci_upper, axis=1)

theta_dict = {'lower': theta_lower,
'upper': theta_upper}
Expand Down
25 changes: 25 additions & 0 deletions doubleml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from doubleml.datasets import make_plr_turrell2018, make_irm_data, \
make_pliv_CHS2015

from doubleml import DoubleMLData


def _g(x):
return np.power(np.sin(x), 2)
Expand All @@ -22,6 +24,29 @@ def _m2(x):
return np.power(x, 2)


@pytest.fixture(scope='session',
                params=[(500, 5)])
def generate_data_simple(request):
    """Build two ``DoubleMLData`` objects sharing one simulated data frame.

    ``request.param`` is an ``(n, p)`` tuple: ``n`` observations and ``p``
    normally distributed covariates. The outcome ``Y`` depends on the first
    binary treatment ``D1`` (coefficient 1.0), on the sum of the covariates
    and on standard normal noise; the second binary treatment ``D2`` is drawn
    the same way as ``D1`` but does not enter ``Y``.

    Returns
    -------
    tuple
        ``(data_d1, data_d2)`` built from the same data frame with outcome
        ``'Y'`` and ``'D1'`` respectively ``'D2'`` as third argument.
    """
    n_obs, n_covariates = request.param
    # Fixed seed on the global NumPy RNG; the draw order below (D1, D2, X,
    # noise) determines the generated values and must not be reordered.
    np.random.seed(1111)
    effect = 1.0

    treat_1 = (np.random.uniform(size=n_obs) > 0.5).astype(float)
    treat_2 = (np.random.uniform(size=n_obs) > 0.5).astype(float)
    covariates = np.random.normal(size=(n_obs, n_covariates))
    noise = np.random.normal(size=n_obs)
    outcome = effect * treat_1 + np.dot(covariates, np.ones(n_covariates)) + noise

    column_names = [f'X{idx + 1}' for idx in range(n_covariates)]
    column_names += ['Y', 'D1', 'D2']
    frame = pd.DataFrame(
        np.column_stack((covariates, outcome, treat_1, treat_2)),
        columns=column_names,
    )

    return DoubleMLData(frame, 'Y', 'D1'), DoubleMLData(frame, 'Y', 'D2')


@pytest.fixture(scope='session',
params=[(500, 10),
(1000, 20),
Expand Down
52 changes: 4 additions & 48 deletions doubleml/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,10 +1008,11 @@ def test_doubleml_sensitivity_not_yet_implemented():

dml_pliv = DoubleMLPLIV(dml_data_pliv, ml_g, ml_m, ml_r)
dml_pliv.fit()
msg = "Sensitivity analysis not yet implemented for DoubleMLPLIV."
msg = 'Sensitivity analysis is not implemented for this model.'
with pytest.raises(NotImplementedError, match=msg):
_ = dml_pliv.sensitivity_analysis()

msg = 'Sensitivity analysis not yet implemented for DoubleMLPLIV.'
with pytest.raises(NotImplementedError, match=msg):
_ = dml_pliv.sensitivity_benchmark(benchmarking_set=["X1"])

Expand All @@ -1025,77 +1026,45 @@ def test_doubleml_sensitivity_inputs():
msg = "cf_y must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=1, cf_d=0.03, rho=1.0, level=0.95)

msg = r'cf_y must be in \[0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=1.0, cf_d=0.03, rho=1.0, level=0.95)

# test cf_d
msg = "cf_d must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=1, rho=1.0, level=0.95)

msg = r'cf_d must be in \[0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=1.0, rho=1.0, level=0.95)

# test rho
msg = "rho must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1, level=0.95)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1, null_hypothesis=0.0, level=0.95, idx_treatment=0)

msg = "rho must be of float type. 1 of type <class 'str'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1")
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho="1", level=0.95)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho="1", null_hypothesis=0.0, level=0.95, idx_treatment=0)

msg = r'The absolute value of rho must be in \[0,1\]. 1.1 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.1, level=0.95)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.1, null_hypothesis=0.0, level=0.95, idx_treatment=0)

# test level
msg = "The confidence level must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1)
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=1, null_hypothesis=0.0, idx_treatment=0)

msg = r'The confidence level must be in \(0,1\). 1.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=1.0, null_hypothesis=0.0, idx_treatment=0)

msg = r'The confidence level must be in \(0,1\). 0.0 was passed.'
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_sensitivity_analysis(cf_y=0.1, cf_d=0.15, rho=1.0, level=0.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(rho=1.0, level=0.0, null_hypothesis=0.0, idx_treatment=0)

# test null_hypothesis
msg = "null_hypothesis has to be of type float or np.ndarry. 1 of type <class 'int'> was passed."
Expand All @@ -1104,30 +1073,18 @@ def test_doubleml_sensitivity_inputs():
msg = r"null_hypothesis is numpy.ndarray but does not have the required shape \(1,\). Array of shape \(2,\) was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_analysis(null_hypothesis=np.array([1, 2]))
msg = "null_hypothesis must be of float type. 1 of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(null_hypothesis=1, level=0.95, rho=1.0, idx_treatment=0)
msg = r"null_hypothesis must be of float type. \[1\] of type <class 'numpy.ndarray'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(null_hypothesis=np.array([1]), level=0.95, rho=1.0, idx_treatment=0)

# test idx_treatment
dml_irm.sensitivity_analysis()
msg = "idx_treatment must be an integer. 0.0 of type <class 'float'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=0.0, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(TypeError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=0.0)

msg = "idx_treatment must be larger or equal to 0. -1 was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=-1, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=-1)

msg = "idx_treatment must be smaller or equal to 0. 1 was passed."
with pytest.raises(ValueError, match=msg):
_ = dml_irm._calc_robustness_value(idx_treatment=1, null_hypothesis=0.0, level=0.95, rho=1.0)
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot(idx_treatment=1)

Expand All @@ -1142,7 +1099,7 @@ def test_doubleml_sensitivity_inputs():
_ = dml_irm._set_sensitivity_elements(sensitivity_elements=sensitivity_elements, i_rep=0, i_treat=0)

# test variances
sensitivity_elements = dict({'sigma2': 1.0, 'nu2': -2.4, 'psi_sigma2': 1.0, 'psi_nu2': 1.0})
sensitivity_elements = dict({'sigma2': 1.0, 'nu2': -2.4, 'psi_sigma2': 1.0, 'psi_nu2': 1.0, 'riesz_rep': 1.0})
_ = dml_irm._set_sensitivity_elements(sensitivity_elements=sensitivity_elements, i_rep=0, i_treat=0)
msg = ('sensitivity_elements sigma2 and nu2 have to be positive. '
r'Got sigma2 \[\[\[1.\]\]\] and nu2 \[\[\[-2.4\]\]\]. '
Expand Down Expand Up @@ -1176,8 +1133,7 @@ def test_doubleml_sensitivity_plot_input():
dml_irm = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression(), trimming_threshold=0.1)
dml_irm.fit()

msg = (r'Apply sensitivity_analysis\(\) to include senario in sensitivity_plot. '
'The values of rho and the level are used for the scenario.')
msg = (r'Apply sensitivity_analysis\(\) to include senario in sensitivity_plot. ')
with pytest.raises(ValueError, match=msg):
_ = dml_irm.sensitivity_plot()

Expand Down
26 changes: 16 additions & 10 deletions doubleml/tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def dml_framework_from_doubleml_fixture(n_rep):
ml_g = LinearRegression()
ml_m = LogisticRegression()

dml_irm_obj = DoubleMLIRM(dml_data, ml_g, ml_m)
dml_irm_obj = DoubleMLIRM(dml_data, ml_g, ml_m, n_rep=n_rep)
dml_irm_obj.fit()
dml_framework_obj = dml_irm_obj.construct_framework()

Expand All @@ -179,7 +179,7 @@ def dml_framework_from_doubleml_fixture(n_rep):

# subtract objects
dml_data_2 = make_irm_data()
dml_irm_obj_2 = DoubleMLIRM(dml_data_2, ml_g, ml_m)
dml_irm_obj_2 = DoubleMLIRM(dml_data_2, ml_g, ml_m, n_rep=n_rep)
dml_irm_obj_2.fit()
dml_framework_obj_2 = dml_irm_obj_2.construct_framework()

Expand Down Expand Up @@ -218,6 +218,7 @@ def dml_framework_from_doubleml_fixture(n_rep):
'ci_joint_sub_obj': ci_joint_sub_obj,
'ci_joint_mul_obj': ci_joint_mul_obj,
'ci_joint_concat': ci_joint_concat,
'n_rep': n_rep,
}
return result_dict

Expand Down Expand Up @@ -257,14 +258,19 @@ def test_dml_framework_from_doubleml_se(dml_framework_from_doubleml_fixture):
dml_framework_from_doubleml_fixture['dml_framework_obj_add_obj'].all_ses,
2*dml_framework_from_doubleml_fixture['dml_obj'].all_se
)
scaling = np.array([dml_framework_from_doubleml_fixture['dml_obj']._var_scaling_factors]).reshape(-1, 1)
sub_var = np.mean(
np.square(dml_framework_from_doubleml_fixture['dml_obj'].psi - dml_framework_from_doubleml_fixture['dml_obj_2'].psi),
axis=0)
assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_sub_obj'].all_ses,
np.sqrt(sub_var / scaling)
)

if dml_framework_from_doubleml_fixture['n_rep'] == 1:
# formula only valid for n_rep = 1
scaling = np.array([dml_framework_from_doubleml_fixture['dml_obj']._var_scaling_factors]).reshape(-1, 1)
sub_var = np.mean(
np.square(dml_framework_from_doubleml_fixture['dml_obj'].psi
- dml_framework_from_doubleml_fixture['dml_obj_2'].psi),
axis=0)
assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_sub_obj'].all_ses,
np.sqrt(sub_var / scaling)
)

assert np.allclose(
dml_framework_from_doubleml_fixture['dml_framework_obj_mul_obj'].all_ses,
2*dml_framework_from_doubleml_fixture['dml_obj'].all_se
Expand Down
Loading

0 comments on commit 20d9864

Please sign in to comment.