Skip to content

Commit

Permalink
fix estimads tests
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Aug 4, 2023
1 parent 6afc9a9 commit e447253
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 42 deletions.
34 changes: 26 additions & 8 deletions src/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ CM₂ = CM(
)
```
"""
@option struct CM <: Estimand
@option struct ConditionalMean <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const CM = ConditionalMean

name(::Type{CM}) = "CM"

#####################################################################
### Average Treatment Effect ###
#####################################################################
Expand Down Expand Up @@ -106,12 +110,16 @@ ATE₂ = ATE(
)
```
"""
@option struct ATE <: Estimand
@option struct AverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const ATE = AverageTreatmentEffect

name(::Type{ATE}) = "ATE"

#####################################################################
### Interaction Average Treatment Effect ###
#####################################################################
Expand Down Expand Up @@ -142,20 +150,31 @@ IATE₁ = IATE(
)
```
"""
@option struct IATE <: Estimand
@option struct InteractionAverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const IATE = InteractionAverageTreatmentEffect

name(::Type{IATE}) = "IATE"

#####################################################################
### Methods ###
#####################################################################

CMCompositeEstimand = Union{CM, ATE, IATE}

Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand = println(io, T)
function Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand
param_string = string(
name(T),
"\n-----",
"\nOutcome: ", Ψ.outcome,
"\nTreatment: ", Ψ.treatment
)
println(io, param_string)
end

equations_to_fit::CMCompositeEstimand) = (outcome_equation(Ψ), (Ψ.scm[t] for t in treatments(Ψ))...)

Expand Down Expand Up @@ -236,10 +255,9 @@ namedtuples_from_dicts(d::Dict) =

function param_key::CMCompositeEstimand)
return (
join.confounders, "_"),
join(keys.treatment), "_"),
string.outcome),
join.covariates, "_")
join(values(confounders(Ψ))..., "_"),
join(treatments(Ψ), "_"),
string(outcome(Ψ)),
)
end

Expand Down
87 changes: 53 additions & 34 deletions test/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ end
outcome =:Ycat,
)
log_sequence = (
(:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."),
(:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."),
(:info, "Fitting Structural Equation corresponding to variable T₂."),
)
@test_logs log_sequence... fit!(Ψ, dataset, verbosity=1)
Expand All @@ -139,73 +137,94 @@ end
)
log_sequence = (
(:info, "Fitting Structural Equation corresponding to variable Ycont."),
(:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."),
(:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."),
)
@test_logs log_sequence... fit!(Ψ, dataset, verbosity=1)
@test scm.Ycat.mach isa Machine
@test scm.T₁.mach isa Machine
@test scm.T₂.mach isa Machine
@test scm.Ycont.mach isa Machine

# Change a model
scm.Ycont.model = TreatmentTransformer() |> LinearRegressor(fit_intercept=false)
log_sequence = (
(:info, "Fitting Structural Equation corresponding to variable Ycont."),
)
@test_logs log_sequence... fit!(Ψ, dataset, verbosity=1)
end

@testset "Test optimize_ordering" begin
rng = StableRNG(123)
scm = SCM(
SE(:T₁, [:W₁, :W₂]),
SE(:T₂, [:W₁, :W₂, :W₃]),
SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁]),
SE(:Y₂, [:T₂, :W₁, :W₂, :W₃]),
)
estimands = [
ATE(
outcome=:Y,
scm=scm,
outcome=:Y₁,
treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")),
confounders=[:W₁, :W₂]
),
IATE(
scm=scm,
outcome=:Y₁,
treatment=(T₁=(case=1, control=0), T₂=(case="AA", control="CC")),
),
ATE(
outcome=:Y,
scm=scm,
outcome=:Y₁,
treatment=(T₁=(case=1, control=0),),
confounders=[:W₁]
),
ATE(
scm=scm,
outcome=:Y₂,
treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")),
confounders=[:W₁, :W₂],
treatment=(T₂=(case="AC", control="CC"),),
),
CM(
outcome=:Y,
treatment=(T₁=0,),
confounders=[:W₁],
covariates=[:C₁]
scm=scm,
outcome=:Y₂,
treatment=(T₂="AC",),
),
IATE(
outcome=:Y,
treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")),
confounders=[:W₁, :W₂]
scm=scm,
outcome=:Y₁,
treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC"),),
),
CM(
outcome=:Y,
treatment=(T₁=0,),
confounders=[:W₁]
scm=scm,
outcome=:Y₂,
treatment=(T₂="CC",),
),
ATE(
scm=scm,
outcome=:Y₂,
treatment=(T₁=(case=1, control=0),),
confounders=[:W₁],
covariates=[:C₁]
treatment=(T₂=(case="AA", control="CC"),),
),
ATE(
scm=scm,
outcome=:Y₂,
treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC")),
confounders=[:W₁, :W₂],
covariates=[:C₂]
treatment=(T₂=(case="AA", control="AC"),),
),
]
# Test param_key
@test TMLE.param_key(estimands[1]) == ("W₁_W₂", "T₁_T₂", "Y₁")
@test TMLE.param_key(estimands[end]) == ("W₁_W₂_W₃", "T₂", "Y₂")
# Non mutating function
estimands = shuffle(rng, estimands)
ordered_estimands = optimize_ordering(estimands)
expected_ordering = [
ATE(:Y, (T₁ = (case = 1, control = 0),), [:W₁], Symbol[]),
CM(:Y, (T₁ = 0,), [:W₁], Symbol[]),
CM(:Y, (T₁ = 0,), [:W₁], [:C₁]),
ATE(:Y₂, (T₁ = (case = 1, control = 0),), [:W₁], [:C₁]),
ATE(:Y, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]),
IATE(:Y, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]),
ATE(:Y₂, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]),
ATE(:Y₂, (T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], [:C₂])
# Y₁
ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),)),
ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AC", control = "CC"))),
IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC"),)),
IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AA", control = "CC"))),
# Y₂
ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "AC"),)),
CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "CC",)),
ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "CC"),)),
CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "AC",)),
ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AC", control = "CC"),)),
]
@test ordered_estimands == expected_ordering
# Mutating function
Expand Down

0 comments on commit e447253

Please sign in to comment.