Skip to content

Commit

Permalink
add more docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Aug 7, 2023
1 parent eb180f7 commit 3b6e05e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 5 deletions.
9 changes: 9 additions & 0 deletions src/adjustment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ struct BackdoorAdjustment <: AdjustmentMethod
outcome_extra::Vector{Symbol}
end

"""
    BackdoorAdjustment(;outcome_extra=Symbol[])

The adjustment set for each treatment variable is simply the set of direct parents in the
associated structural model.

`outcome_extra` are optional additional variables that can be used to fit the outcome model
in order to improve inference.
"""
# The default is a typed empty vector so it already matches the `Vector{Symbol}` field
# instead of relying on an implicit conversion from `Vector{Any}`.
BackdoorAdjustment(;outcome_extra=Symbol[]) = BackdoorAdjustment(outcome_extra)

"""
Expand Down
80 changes: 75 additions & 5 deletions src/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,32 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ.
"""
abstract type Estimand end


#####################################################################
### Conditional Mean ###
#####################################################################

"""
# Conditional Mean / CM
## Definition
``CM(Y, T=t) = E[Y|do(T=t)]``
## Constructors
- CM(;scm::SCM, outcome, treatment)
- CM(scm::SCM; outcome, treatment)
where:
- scm: is a `StructuralCausalModel` (see [`SCM`](@ref))
- outcome: is a `Symbol`
- treatment: is a `NamedTuple`
## Example
Ψ = CM(scm, outcome=:Y, treatment=(T=1,))
"""
struct ConditionalMean <: Estimand
scm::StructuralCausalModel
outcome::Symbol
Expand All @@ -23,15 +44,36 @@ end

# Short alias for `ConditionalMean`.
const CM = ConditionalMean

# Convenience constructors: keyword-only and scm-first. Both forward to the positional
# constructor `CM(scm, outcome, treatment)`.
# Only the `scm::SCM` annotated definitions are kept. Keyword arguments do not take part
# in dispatch, so a second keyword-only definition without the annotation shares the same
# (empty) positional signature and would simply be overwritten.
CM(;scm::SCM, outcome, treatment) = CM(scm, outcome, treatment)
CM(scm::SCM; outcome, treatment) = CM(scm, outcome, treatment)

# Short display name of the estimand type.
name(::Type{CM}) = "CM"

#####################################################################
### Average Treatment Effect ###
#####################################################################
"""
# Average Treatment Effect / ATE
## Definition
``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)]``
## Constructors
- ATE(;scm::SCM, outcome, treatment)
- ATE(scm::SCM; outcome, treatment)
where:
- scm: is a `StructuralCausalModel` (see [`SCM`](@ref))
- outcome: is a `Symbol`
- treatment: is a `NamedTuple`
## Example
Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),))
"""
struct AverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
Expand All @@ -44,15 +86,39 @@ end

# Short alias for `AverageTreatmentEffect`.
const ATE = AverageTreatmentEffect

# Convenience constructors: keyword-only and scm-first. Both forward to the positional
# constructor `ATE(scm, outcome, treatment)`.
# Only the `scm::SCM` annotated definitions are kept. Keyword arguments do not take part
# in dispatch, so a second keyword-only definition without the annotation shares the same
# (empty) positional signature and would simply be overwritten.
ATE(;scm::SCM, outcome, treatment) = ATE(scm, outcome, treatment)
ATE(scm::SCM; outcome, treatment) = ATE(scm, outcome, treatment)

# Short display name of the estimand type.
name(::Type{ATE}) = "ATE"

#####################################################################
### Interaction Average Treatment Effect ###
#####################################################################

"""
# Interaction Average Treatment Effect / IATE
## Definition
For two treatments with settings (1, 0):
``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]``
## Constructors
- IATE(;scm::SCM, outcome, treatment)
- IATE(scm::SCM; outcome, treatment)
where:
- scm: is a `StructuralCausalModel` (see [`SCM`](@ref))
- outcome: is a `Symbol`
- treatment: is a `NamedTuple`
## Example
Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)))
"""
struct InteractionAverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
Expand All @@ -74,6 +140,8 @@ name(::Type{IATE}) = "IATE"
### Methods ###
#####################################################################

# Tuple of the estimand types defined in this file.
# Declared `const` so the binding has a fixed type and cannot be silently rebound.
const AVAILABLE_ESTIMANDS = (CM, ATE, IATE)

# Union of the estimand types; used as an argument annotation by methods such as
# `tmle` and `MLJBase.fit!`. Declared `const` because it serves as a type alias.
const CMCompositeEstimand = Union{CM, ATE, IATE}

"""
    VariableNotAChildInSCMError(variable)

Build (without throwing) the `ArgumentError` stating that `variable` is not associated
with a Structural Equation in the SCM.
"""
function VariableNotAChildInSCMError(variable)
    message = string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")
    return ArgumentError(message)
end
Expand Down Expand Up @@ -126,6 +194,8 @@ end
"""
    optimize_ordering!(estimands::Vector{<:Estimand})

Sort `estimands` in place, using `param_key` as the sorting key, and return the sorted
vector. The resulting order is meant to maximize reuse of fitted equations in the
associated SCM.
"""
function optimize_ordering!(estimands::Vector{<:Estimand})
    return sort!(estimands; by=param_key)
end

Expand Down
21 changes: 21 additions & 0 deletions src/estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,27 @@ function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=Backdo
end
end

"""
tmle(Ψ::CMCompositeEstimand, dataset;
adjustment_method=BackdoorAdjustment(),
verbosity=1,
force=false,
threshold=1e-8,
weighted_fluctuation=false
)
Performs Targeted Minimum Loss Based Estimation of the target estimand.
## Arguments
- Ψ: An estimand of interest.
- dataset: A table respecting the `Tables.jl` interface.
- adjustment_method: A confounding adjustment method.
- verbosity: Level of logging.
- force: To force refit of machines in the SCM.
- threshold: The balancing score will be bounded to respect this threshold.
- weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE; this can improve stability.
"""
function tmle(Ψ::CMCompositeEstimand, dataset;
adjustment_method=BackdoorAdjustment(),
verbosity=1,
Expand Down
63 changes: 63 additions & 0 deletions src/scm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
#####################################################################
### Structural Equation ###
#####################################################################
"""
# Structural Equation / SE
## Constructors
- SE(outcome, parents; model=nothing)
- SE(;outcome, parents, model=nothing)
## Examples
eq = SE(:Y, [:T, :W])
eq = SE(:Y, [:T, :W], model = LinearRegressor())
"""
mutable struct StructuralEquation
outcome::Symbol
parents::Vector{Symbol}
Expand Down Expand Up @@ -95,6 +109,22 @@ parents(se::SE) = se.parents

"""
    AlreadyAssignedError(key)

Build (without throwing) the `ArgumentError` stating that variable `key` is already
assigned in the SCM.
"""
function AlreadyAssignedError(key)
    message = string("Variable ", key, " is already assigned in the SCM.")
    return ArgumentError(message)
end

"""
# Structural Causal Model / SCM

A collection of structural equations, stored in a dictionary with `Symbol` keys.

## Constructors

- `SCM(;equations=Dict{Symbol, SE}())`
- `SCM(equations::Vararg{SE})`

## Examples

```julia
scm = SCM(
    SE(:Y, [:T, :W, :C]),
    SE(:T, [:W])
)
```
"""
struct StructuralCausalModel
    # Structural equations indexed by a `Symbol` (presumably the outcome of each
    # equation; confirm against the constructors).
    equations::Dict{Symbol, StructuralEquation}
end
Expand Down Expand Up @@ -191,6 +221,20 @@ end
"""
    vcat_covariates(treatment, confounders, covariates)

Concatenate `treatment`, `confounders` and `covariates` with `vcat`. When `covariates`
is `nothing`, only `treatment` and `confounders` are concatenated.
"""
function vcat_covariates(treatment, confounders, ::Nothing)
    return vcat(treatment, confounders)
end

function vcat_covariates(treatment, confounders, covariates)
    return vcat(treatment, confounders, covariates)
end

"""
StaticConfoundedModel(
outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}};
covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing,
outcome_model = TreatmentTransformer() |> LinearRegressor(),
treatment_model = LinearBinaryClassifier()
)
Defines a classic Structural Causal Model with one outcome, one treatment,
a set of confounding variables and optional covariates influencing the outcome only.
The `outcome_model` and `treatment_model` define the relationship between
the outcome (resp. treatment) and their ancestors.
"""
function StaticConfoundedModel(
outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}};
covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing,
Expand All @@ -210,6 +254,25 @@ function StaticConfoundedModel(
return StructuralCausalModel(Yeq, Teq)
end

"""
StaticConfoundedModel(
outcomes::Vector{Symbol},
treatments::Vector{Symbol},
confounders::Union{Symbol, AbstractVector{Symbol}};
covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing,
outcome_model = TreatmentTransformer() |> LinearRegressor(),
treatment_model = LinearBinaryClassifier()
)
Defines a classic Structural Causal Model with multiple outcomes, multiple treatments,
a set of confounding variables and optional covariates influencing the outcomes only.
All treatments are assumed to be direct parents of all outcomes. The confounding variables
are shared for all treatments.
The `outcome_model` and `treatment_model` define the relationships between
the outcomes (resp. treatments) and their ancestors.
"""
function StaticConfoundedModel(
outcomes::Vector{Symbol},
treatments::Vector{Symbol},
Expand Down

0 comments on commit 3b6e05e

Please sign in to comment.