diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 8875463..2aa9073 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -265,6 +265,12 @@ get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_val get_treatment_settings(::typeof(CM), treatments_unique_values) = treatments_unique_values +""" +Allows for treatment settings to be generated from a Dictionary +""" +get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values::Dict{Symbol, Vector{UInt8}}) = + Dict(k => collect(zip(v[1:end-1], v[2:end])) for (k, v) in treatments_unique_values) + get_treatment_setting(combo::Tuple{Vararg{Tuple}}) = [NamedTuple{(:control, :case)}(treatment_control_case) for treatment_control_case ∈ combo] get_treatment_setting(combo) = collect(combo) @@ -332,6 +338,55 @@ function _factorialEstimand( return JointEstimand(components...) end +""" +Dictionary implementation +""" +function _factorialEstimand( + constructor, + treatments_settings::Dict{Symbol, Vector{Tuple{UInt8, UInt8}}}, + outcome; + confounders=nothing, + outcome_extra_covariates=nothing, + freq_table=nothing, + positivity_constraint=nothing, + verbosity=1 + ) + components = [] + + # Get the keys and values from the dictionary + keys_ = collect(keys(treatments_settings)) + + # Define the names for the NamedTuple + names = Tuple(keys_) + + # Iterate through the product of all treatment settings + for combo ∈ Iterators.product(values(treatments_settings)...) + # Convert the combination back to a NamedTuple + treatment_values = NamedTuple{names}(TMLE.get_treatment_setting(combo)) + + # Construct the Ψ object + Ψ = constructor( + outcome=outcome, + treatment_values=treatment_values, + treatment_confounders=confounders, + outcome_extra_covariates=outcome_extra_covariates + ) + + # Check the positivity constraint + if TMLE.satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint) + push!(components, Ψ) + else + verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.") + end + end + + if length(components) == 0 + throw(ArgumentError("No component passed the positivity constraint.")) + end + + return JointEstimand(components...) +end + """ factorialEstimand( constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},