Skip to content

Commit

Permalink
tests working locally
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Aug 7, 2023
1 parent b7f8f6f commit 9701cd7
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 67 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.11.4"
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Expand All @@ -27,7 +26,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
AbstractDifferentiation = "0.4, 0.5"
CategoricalArrays = "0.10"
Configurations = "0.17"
Distributions = "0.25"
GLM = "1.8.2"
HypothesisTests = "0.10"
Expand All @@ -36,10 +34,10 @@ MLJBase = "0.19, 0.20, 0.21"
MLJGLMInterface = "0.3.4"
MLJModels = "0.15, 0.16"
Missings = "1.0"
PrecompileTools = "1.1.1"
PrettyTables = "2.2"
TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4"
Zygote = "0.6"
PrecompileTools = "1.1.1"
julia = "1.6, 1.7, 1"
2 changes: 0 additions & 2 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using Distributions
using Zygote
using LogExpFunctions
using YAML
using Configurations
using PrecompileTools
using PrettyTables
using Random
Expand Down Expand Up @@ -44,7 +43,6 @@ include("estimands.jl")
include("utils.jl")
include("estimation.jl")
include("estimate.jl")
include("configuration.jl")

# #############################################################################
# PRECOMPILATION WORKLOAD
Expand Down
17 changes: 13 additions & 4 deletions src/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@ CM₂ = CM(
)
```
"""
@option struct ConditionalMean <: Estimand
struct ConditionalMean <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const CM = ConditionalMean

CM(;scm, outcome, treatment) = CM(scm, outcome, treatment)
CM(scm; outcome, treatment) = CM(scm, outcome, treatment)

name(::Type{CM}) = "CM"

#####################################################################
Expand Down Expand Up @@ -110,14 +113,17 @@ ATE₂ = ATE(
)
```
"""
@option struct AverageTreatmentEffect <: Estimand
struct AverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const ATE = AverageTreatmentEffect

ATE(;scm, outcome, treatment) = ATE(scm, outcome, treatment)
ATE(scm; outcome, treatment) = ATE(scm, outcome, treatment)

name(::Type{ATE}) = "ATE"

#####################################################################
Expand Down Expand Up @@ -150,14 +156,17 @@ IATE₁ = IATE(
)
```
"""
@option struct InteractionAverageTreatmentEffect <: Estimand
struct InteractionAverageTreatmentEffect <: Estimand
scm::StructuralCausalModel
outcome::Symbol
treatment::NamedTuple
end

const IATE = InteractionAverageTreatmentEffect

IATE(;scm, outcome, treatment) = IATE(scm, outcome, treatment)
IATE(scm; outcome, treatment) = IATE(scm, outcome, treatment)

name(::Type{IATE}) = "IATE"

#####################################################################
Expand All @@ -166,7 +175,7 @@ name(::Type{IATE}) = "IATE"

CMCompositeEstimand = Union{CM, ATE, IATE}

function Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand
function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand
param_string = string(
name(T),
"\n-----",
Expand Down
17 changes: 11 additions & 6 deletions src/scm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ SelfReferringEquationError(outcome) =

NoModelError(eq::SE) = ArgumentError(string("It seems the following structural equation needs to be fitted.\n",
" Please provide a suitable model for it :\n\t", eq))


fit_message(eq::SE) = string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")

function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false)
eq.model !== nothing || throw(NoModelError(eq))
# Fit if never fitted or if new model
if eq.mach === nothing || eq.model != eq.mach.model
verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), "."))
verbosity >= 1 && @info(fit_message(eq))
data = nomissing(dataset, vcat(parents(eq), outcome(eq)))
X = selectcols(data, parents(eq))
y = Tables.getcolumn(data, outcome(eq))
Expand All @@ -35,6 +37,7 @@ function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false)
eq.mach = mach
# Otherwise only fit if force is true
else
verbosity >= 1 && force === true && @info(fit_message(eq))
MLJBase.fit!(eq.mach, verbosity=verbosity-1, force=force)
end
end
Expand All @@ -49,7 +52,7 @@ function string_repr(eq::SE; subscript="")
return eq_string
end

Base.show(io::IO, eq::SE) = println(io, string_repr(eq))
Base.show(io::IO, ::MIME"text/plain", eq::SE) = println(io, string_repr(eq))

assign_model!(eq::SE, model::Nothing) = nothing
assign_model!(eq::SE, model::Model) = eq.model = model
Expand All @@ -69,8 +72,10 @@ end

const SCM = StructuralCausalModel

StructuralCausalModel(equations::Vararg{SE}) =
StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations))
SCM() = SCM(Dict{Symbol, SE}())

SCM(equations::Vararg{SE}) =
SCM(Dict(outcome(eq) => eq for eq in equations))


equations(scm::SCM) = scm.equations
Expand All @@ -88,7 +93,7 @@ function string_repr(scm::SCM)
return scm_string
end

Base.show(io::IO, scm::SCM) = println(io, string_repr(scm))
Base.show(io::IO, ::MIME"text/plain", scm::SCM) = println(io, string_repr(scm))


function Base.push!(scm::SCM, eq::SE)
Expand Down
38 changes: 0 additions & 38 deletions test/configuration.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/double_robustness_ate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ end
outcome = :Y,
treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)),
)
tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=1)
tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0)
test_coverage(tmle_result, ATE₁₁₋₀₀)
test_fluct_decreases_risk(Ψ, fluctuation_mach)
test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test

@time begin
@test include("scm.jl")
@test include("non_regression_test.jl")
@test include("utils.jl")
@test include("double_robustness_ate.jl")
Expand All @@ -11,5 +12,4 @@ using Test
@test include("missing_management.jl")
@test include("composition.jl")
@test include("treatment_transformer.jl")
@test include("configuration.jl")
end
15 changes: 3 additions & 12 deletions test/scm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,10 @@ end
@test isdefined(eq.mach, :data)
end
# Refit will not do anything
nofit_log_sequence = (
(:info, "Structural Equation corresponding to variable Ycont already fitted, skipping. Set `force=true` to force refit."),
(:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."),
(:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."),
(:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."),
)
nofit_log_sequence = ()
@test_logs nofit_log_sequence... fit!(scm, dataset, verbosity = 1)
# Force refit and set cache to false
@test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true, cache=false)
for (key, eq) in equations(scm)
@test eq.mach isa Machine
@test !isdefined(eq.mach, :data)
end
# Force refit
@test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true)

# Reset scm
reset!(scm)
Expand Down

0 comments on commit 9701cd7

Please sign in to comment.