Skip to content

Commit

Permalink
_factorialEstimand and get_treatment_settings overloaded to handle di…
Browse files Browse the repository at this point in the history
…ctionaries (outputs NamedTuple)
  • Loading branch information
joshua-slaughter committed Jul 25, 2024
1 parent b694b7b commit 4f25ca9
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_val

get_treatment_settings(::typeof(CM), treatments_unique_values) = treatments_unique_values

"""
Allows for treatment settings to be generated from a Dictionary
"""
get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values::Dict{Symbol, Vector{UInt8}}) =
Dict(k => collect(zip(v[1:end-1], v[2:end])) for (k, v) in treatments_unique_values)

get_treatment_setting(combo::Tuple{Vararg{Tuple}}) = [NamedTuple{(:control, :case)}(treatment_control_case) for treatment_control_case combo]

get_treatment_setting(combo) = collect(combo)
Expand Down Expand Up @@ -332,6 +338,55 @@ function _factorialEstimand(
return JointEstimand(components...)
end

"""
Dictionary implementation
"""
function _factorialEstimand(
constructor,
treatments_settings::Dict{Symbol, Vector{Tuple{UInt8, UInt8}}},
outcome;
confounders=nothing,
outcome_extra_covariates=nothing,
freq_table=nothing,
positivity_constraint=nothing,
verbosity=1
)
components = []

# Get the keys and values from the dictionary
keys_ = collect(keys(treatments_settings))

# Define the names for the NamedTuple
names = Tuple(keys_)

# Iterate through the product of all treatment settings
for combo Iterators.product(values(treatments_settings)...)
# Convert the combination back to a NamedTuple
treatment_values = NamedTuple{names}(TMLE.get_treatment_setting(combo))

# Construct the Ψ object
Ψ = constructor(
outcome=outcome,
treatment_values=treatment_values,
treatment_confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates
)

# Check the positivity constraint
if TMLE.satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)
push!(components, Ψ)
else
verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.")
end
end

if length(components) == 0
throw(ArgumentError("No component passed the positivity constraint."))
end

return JointEstimand(components...)
end

"""
factorialEstimand(
constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},
Expand Down

0 comments on commit 4f25ca9

Please sign in to comment.