Commit: update models to dict as well

olivierlabayle committed Aug 2, 2024 · 1 parent 4404fda · commit 14544c7
Showing 11 changed files with 125 additions and 120 deletions.
21 changes: 10 additions & 11 deletions docs/src/user_guide/estimation.md
@@ -110,7 +110,7 @@ Again, required nuisance functions are fitted and stored in the cache.

## Specifying Models

-By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is essentially a `NamedTuple` mapping variables' names to their respective model.
+By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is a `Dict{Symbol, Model}` mapping variables' names to their respective model.

Rather than specifying a specific model for each variable it may be easier to override the default models using the `default_models` function:

@@ -121,9 +121,9 @@ using MLJXGBoostInterface
xgboost_regressor = XGBoostRegressor()
xgboost_classifier = XGBoostClassifier()
models = default_models(
-    Q_binary=xgboost_classifier,
-    Q_continuous=xgboost_regressor,
-    G=xgboost_classifier
+    Q_binary = xgboost_classifier,
+    Q_continuous = xgboost_regressor,
+    G = xgboost_classifier
)
tmle_gboost = TMLEE(models=models)
```
@@ -140,19 +140,18 @@ stack_binary = Stack(
    lr=lr
)
-models = (
-    T₁ = with_encoder(xgboost_classifier), # T₁ with XGBoost prepended with a Continuous Encoder
-    default_models( # For all other variables use the following defaults
-        Q_binary=stack_binary, # A Super Learner
-        Q_continuous=xgboost_regressor, # An XGBoost
+models = default_models( # For all non-specified variables use the following defaults
+    Q_binary = stack_binary, # A Super Learner
+    Q_continuous = xgboost_regressor, # An XGBoost
+    # T₁ with XGBoost prepended with a Continuous Encoder
+    T₁ = xgboost_classifier
    # Unspecified G defaults to Logistic Regression
-    )...
)
tmle_custom = TMLEE(models=models)
```

-Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `NamedTuple`.
+Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `Dict`.

## CV-Estimation

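(Reviewer note: a minimal sketch of the Dict-based workflow the updated docs describe, using only names that appear in this diff (`default_models`, `TMLEE`, and the MLJXGBoostInterface constructors); the `T₁` key is illustrative.)

```julia
using TMLE
using MLJXGBoostInterface

# Defaults plus one variable-specific override, collected in a Dict{Symbol, Model}:
models = default_models(
    Q_continuous = XGBoostRegressor(),  # default for continuous outcomes
    Q_binary = XGBoostClassifier(),     # default for binary outcomes
    T₁ = XGBoostClassifier()            # extra kwarg: a model just for T₁
)

haskey(models, :T₁)                     # true: entries are keyed by Symbol
tmle_gboost = TMLEE(models=models)
```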
12 changes: 6 additions & 6 deletions src/counterfactual_mean_based/estimators.jl
@@ -23,7 +23,7 @@ Base.showerror(io::IO, e::FitFailedError) = print(io, e.msg)

struct CMRelevantFactorsEstimator <: Estimator
    resampling::Union{Nothing, ResamplingStrategy}
-    models::NamedTuple
+    models::Dict
end

CMRelevantFactorsEstimator(;models, resampling=nothing) =
@@ -152,7 +152,7 @@ end
#####################################################################

mutable struct TMLEE <: Estimator
-    models::NamedTuple
+    models::Dict
    resampling::Union{Nothing, ResamplingStrategy}
    ps_lowerbound::Union{Float64, Nothing}
    weighted::Bool
@@ -168,7 +168,7 @@ function that can be applied to estimate estimands for a dataset.
# Arguments
-- models: A NamedTuple{variables}(models) where the `variables` are the outcome variables modeled by the `models`.
+- models: A Dict(variable => model, ...) where the `variables` are the outcome variables modeled by the `models`.
- resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla TMLE while
any valid `MLJ.ResamplingStrategy` will result in CV-TMLE.
- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will
@@ -237,7 +237,7 @@ gradient_and_estimate(::TMLEE, Ψ, factors, dataset; ps_lowerbound=1e-8) =
#####################################################################

mutable struct OSE <: Estimator
-    models::NamedTuple
+    models::Dict
    resampling::Union{Nothing, ResamplingStrategy}
    ps_lowerbound::Union{Float64, Nothing}
    machine_cache::Bool
@@ -251,7 +251,7 @@ function that can be applied to estimate estimands for a dataset.
# Arguments
-- models: A NamedTuple{variables}(models) where the `variables` are the outcome variables modeled by the `models`.
+- models: A Dict(variable => model, ...) where the `variables` are the outcome variables modeled by the `models`.
- resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla estimation while
any valid `MLJ.ResamplingStrategy` will result in CV-OSE.
- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will
@@ -262,7 +262,7 @@ result in a data adaptive definition as described in [here](https://pubmed.ncbi.
```julia
using MLJLinearModels
-models = (Y = LinearRegressor(), T = LogisticClassifier())
+models = Dict(:Y => LinearRegressor(), :T => LogisticClassifier())
ose = OSE()
Ψ̂ₙ, cache = ose(Ψ, dataset)
```
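(Side note: both before and after this change, the `OSE` docstring example builds `models` and then calls `OSE()` without passing it, a pre-existing quirk this commit leaves untouched. A sketch of the presumably intended usage; `Ψ` and `dataset` are assumed to be defined as in the surrounding documentation.)

```julia
using TMLE
using MLJLinearModels

models = Dict(:Y => LinearRegressor(), :T => LogisticClassifier())
ose = OSE(models=models)      # pass the Dict explicitly
# Ψ (the estimand) and dataset are assumed to exist already:
Ψ̂ₙ, cache = ose(Ψ, dataset)
```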
17 changes: 9 additions & 8 deletions src/utils.jl
@@ -82,7 +82,7 @@ end
"""
default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier()) = (
-Create a NamedTuple containing default models to be used by downstream estimators.
+Create a Dictionary containing default models to be used by downstream estimators.
Each provided model is prepended (in a `MLJ.Pipeline`) with an `MLJ.ContinuousEncoder`.
By default:
@@ -96,17 +96,18 @@ The following changes the default `Q_binary` to a `LogisticClassifier` and provi
```julia
using MLJLinearModels
-models = (
-    special_y = RidgeRegressor(),
-    default_models(Q_binary=LogisticClassifier())...
+models = default_models(
+    Q_binary = LogisticClassifier(),
+    special_y = RidgeRegressor()
)
```
"""
-default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier()) = (
-    Q_binary_default = with_encoder(Q_binary),
-    Q_continuous_default = with_encoder(Q_continuous),
-    G_default = with_encoder(G)
+default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier(), kwargs...) = Dict(
+    :Q_binary_default => with_encoder(Q_binary),
+    :Q_continuous_default => with_encoder(Q_continuous),
+    :G_default => with_encoder(G),
+    (key => with_encoder(val) for (key, val) in kwargs)...
)

is_binary(dataset, columnname) = Set(skipmissing(Tables.getcolumn(dataset, columnname))) == Set([0, 1])
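(The rewritten `default_models` now forwards arbitrary keyword arguments into the returned `Dict` and wraps every value in `with_encoder`. A small sketch based only on the definition above; `special_y` is an illustrative key.)

```julia
using TMLE
using MLJLinearModels

models = default_models(
    Q_binary = LogisticClassifier(),
    special_y = RidgeRegressor()
)

# The three *_default entries are always present and extra kwargs ride along;
# each value is an encoder-prefixed pipeline:
sort(collect(keys(models)))
# [:G_default, :Q_binary_default, :Q_continuous_default, :special_y]
```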
12 changes: 6 additions & 6 deletions test/composition.jl
@@ -53,9 +53,9 @@ end
mydiff(x, y) = y - x

jointestimand = JointEstimand(CM₀, CM₁)
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T = LogisticClassifier(lambda=0)
+models = Dict(
+    :Y => with_encoder(LinearRegressor()),
+    :T => LogisticClassifier(lambda=0)
)
tmle = TMLEE(models=models)
ose = OSE(models=models)
@@ -102,9 +102,9 @@ end

@testset "Test compose multidimensional function" begin
dataset = make_dataset(;n=1000)
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T = LogisticClassifier(lambda=0)
+models = Dict(
+    :Y => with_encoder(LinearRegressor()),
+    :T => LogisticClassifier(lambda=0)
)
tmle = TMLEE(models=models)
cache = Dict()
10 changes: 5 additions & 5 deletions test/counterfactual_mean_based/3points_interactions.jl
@@ -36,11 +36,11 @@ end
),
treatment_confounders = (T₁=[:W], T₂=[:W], T₃=[:W])
)
-models = (
-    Y = with_encoder(InteractionTransformer(order=3) |> LinearRegressor()),
-    T₁ = LogisticClassifier(lambda=0),
-    T₂ = LogisticClassifier(lambda=0),
-    T₃ = LogisticClassifier(lambda=0)
+models = Dict(
+    :Y => with_encoder(InteractionTransformer(order=3) |> LinearRegressor()),
+    :T₁ => LogisticClassifier(lambda=0),
+    :T₂ => LogisticClassifier(lambda=0),
+    :T₃ => LogisticClassifier(lambda=0)
)

tmle = TMLEE(models=models, machine_cache=true)
66 changes: 33 additions & 33 deletions test/counterfactual_mean_based/double_robustness_ate.jl
@@ -113,9 +113,9 @@ end
)

# When Q is misspecified but G is well specified
-models = (
-    Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-    T = with_encoder(LogisticClassifier(lambda=0))
+models = Dict(
+    :Y => with_encoder(MLJModels.DeterministicConstantRegressor()),
+    :T => with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -127,14 +127,14 @@ end
@test emptyIC(results.tmle, pval_threshold=0.9pval).IC == []
@test emptyIC(results.tmle, pval_threshold=1.1pval) === results.tmle
# The initial estimate is far away
-naive = NAIVE(models.Y)
+naive = NAIVE(models[:Y])
naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0)
@test naive_result == 0

# When Q is well specified but G is misspecified
-models = (
-    Y = with_encoder(TreatmentTransformer() |> LinearRegressor()),
-    T = with_encoder(ConstantClassifier())
+models = Dict(
+    :Y => with_encoder(TreatmentTransformer() |> LinearRegressor()),
+    :T => with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -149,22 +149,22 @@ end
treatment_confounders = (T=[:W],)
)
# When Q is misspecified but G is well specified
-models = (
-    Y = with_encoder(ConstantClassifier()),
-    T = with_encoder(LogisticClassifier(lambda=0))
+models = Dict(
+    :Y => with_encoder(ConstantClassifier()),
+    :T => with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models, resampling=StratifiedCV())
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6)
test_mean_inf_curve_almost_zero(results.ose; atol=1e-6)
# The initial estimate is far away
-naive = NAIVE(models.Y)
+naive = NAIVE(models[:Y])
naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0)
@test naive_result == 0
# When Q is well specified but G is misspecified
-models = (
-    Y = with_encoder(LogisticClassifier(lambda=0)),
-    T = with_encoder(ConstantClassifier())
+models = Dict(
+    :Y => with_encoder(LogisticClassifier(lambda=0)),
+    :T => with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models, resampling=StratifiedCV())
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -180,24 +180,24 @@ end
treatment_confounders = (T=[:W₁, :W₂, :W₃],)
)
# When Q is misspecified but G is well specified
-models = (
-    Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-    T = with_encoder(LogisticClassifier(lambda=0))
+models = Dict(
+    :Y => with_encoder(MLJModels.DeterministicConstantRegressor()),
+    :T => with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)

test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10)
test_mean_inf_curve_almost_zero(results.ose; atol=1e-10)
# The initial estimate is far away
-naive = NAIVE(models.Y)
+naive = NAIVE(models[:Y])
naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0)
@test naive_result == 0

# When Q is well specified but G is misspecified
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T = with_encoder(ConstantClassifier())
+models = Dict(
+    :Y => with_encoder(LinearRegressor()),
+    :T => with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -219,20 +219,20 @@ end
)
)
# When Q is misspecified but G is well specified
-models = (
-    Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-    T₁ = with_encoder(LogisticClassifier(lambda=0)),
-    T₂ = with_encoder(LogisticClassifier(lambda=0))
+models = Dict(
+    :Y => with_encoder(MLJModels.DeterministicConstantRegressor()),
+    :T₁ => with_encoder(LogisticClassifier(lambda=0)),
+    :T₂ => with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0)
test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10)
test_mean_inf_curve_almost_zero(results.ose; atol=1e-10)
# When Q is well specified but G is misspecified
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T₁ = with_encoder(ConstantClassifier()),
-    T₂ = with_encoder(ConstantClassifier())
+models = Dict(
+    :Y => with_encoder(LinearRegressor()),
+    :T₁ => with_encoder(ConstantClassifier()),
+    :T₂ => with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0)
@@ -256,10 +256,10 @@ end
test_mean_inf_curve_almost_zero(results.ose; atol=1e-10)

# When Q is well specified but G is misspecified
-models = (
-    Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-    T₁ = with_encoder(LogisticClassifier(lambda=0)),
-    T₂ = with_encoder(LogisticClassifier(lambda=0)),
+models = Dict(
+    :Y => with_encoder(MLJModels.DeterministicConstantRegressor()),
+    :T₁ => with_encoder(LogisticClassifier(lambda=0)),
+    :T₂ => with_encoder(LogisticClassifier(lambda=0)),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₀, dataset; verbosity=0)
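(The pattern recurring throughout these tests is the access change that follows from the NamedTuple-to-Dict switch: per-variable models are now looked up by `Symbol` key instead of by field. In short, assuming the test helpers above:)

```julia
# Before: models was a NamedTuple, accessed by field
models = (Y = with_encoder(LinearRegressor()), T = LogisticClassifier(lambda=0))
naive = NAIVE(models.Y)

# After: models is a Dict keyed by Symbol
models = Dict(:Y => with_encoder(LinearRegressor()), :T => LogisticClassifier(lambda=0))
naive = NAIVE(models[:Y])
```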