Add an explicit treatment input argument to causaltree/forest (#776)
* add an explicit treatment input argument to causaltree/forest
* reformat with black
* updated missing changes in #776 per @paullo's feedback
jeongyoonlee committed Jul 5, 2024 · 1 parent a031566 · commit b12c30b
Showing 13 changed files with 138 additions and 65 deletions.
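
Taken together, the diff below threads `treatment` as an explicit, first-class argument from the public `fit()` surface down through the Cython `TreeBuilder`, `Splitter`, and `Criterion` layers, mirroring how `sample_weight` is already passed. A minimal usage sketch of the resulting estimator-level call follows; it assumes `CausalTreeRegressor` is importable from `causalml.inference.tree` and that its `fit()` accepts `X`, `treatment`, and `y` as separate arguments (the forest counterpart, `CausalRandomForestRegressor`, is assumed to behave the same way). The data are synthetic and purely illustrative.

```python
# Minimal usage sketch of the estimator-level API implied by this change.
# Assumptions: CausalTreeRegressor is importable from causalml.inference.tree,
# fit() takes X, treatment, and y as separate arguments, and predict() returns
# per-sample treatment-effect estimates. Data below are synthetic.
import numpy as np
from causalml.inference.tree import CausalTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 5))                                  # covariates
treatment = rng.integers(0, 2, size=1_000)                       # 0 = control, 1 = treated
y = X[:, 0] + 0.5 * treatment + rng.normal(0, 0.1, size=1_000)   # outcome with a +0.5 effect

tree = CausalTreeRegressor()
tree.fit(X=X, treatment=treatment, y=y)    # treatment is now passed explicitly
cate_hat = tree.predict(X)                 # estimated treatment effect per sample
```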
8 changes: 7 additions & 1 deletion causalml/inference/tree/_tree/_classes.py
@@ -89,7 +89,13 @@ def __init__(

@abstractmethod
def fit(
self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated"
self,
X,
treatment,
y,
sample_weight=None,
check_input=True,
X_idx_sorted="deprecated",
):
pass

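
The hunk above changes the abstract `fit` contract in `_classes.py` so that `treatment` sits between `X` and `y`. For orientation, a plain-Python mirror of that contract is sketched below; the class name is hypothetical, and the real base class carries tree-building machinery that is omitted here.

```python
# Hypothetical plain-Python mirror of the updated abstract fit() contract
# (illustrative only; class name and docstring are not from the code base).
from abc import ABC, abstractmethod

class BaseCausalDecisionTree(ABC):
    @abstractmethod
    def fit(self, X, treatment, y, sample_weight=None, check_input=True,
            X_idx_sorted="deprecated"):
        """Fit on covariates X, treatment assignment, and outcome y."""
        ...
```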
3 changes: 2 additions & 1 deletion causalml/inference/tree/_tree/_criterion.pxd
@@ -29,6 +29,7 @@ cdef class Criterion:

# Internal structures
cdef const DOUBLE_t[:, ::1] y # Values of y
cdef DOUBLE_t* treatment # Treatment assignment
cdef DOUBLE_t* sample_weight # Sample weights

cdef SIZE_t* samples # Sample indices in X, y
@@ -56,7 +57,7 @@ cdef class Criterion:
# statistics correspond to samples[start:pos] and samples[pos:end].

# Methods
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1
cdef int reset(self) nogil except -1
8 changes: 6 additions & 2 deletions causalml/inference/tree/_tree/_criterion.pyx
@@ -48,7 +48,7 @@ cdef class Criterion:
def __setstate__(self, d):
pass

cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1:
"""Placeholder for a method which will initialize the criterion.
@@ -60,6 +60,8 @@
----------
y : array-like, dtype=DOUBLE_t
y is a buffer that can store values for n_outputs target variables
treatment : array-like, dtype=DOUBLE_t
The treatment assignment of each sample.
sample_weight : array-like, dtype=DOUBLE_t
The weight of each sample
weighted_n_samples : double
@@ -224,6 +226,7 @@ cdef class RegressionCriterion(Criterion):
The total number of samples to fit on
"""
# Default values
self.treatment = NULL
self.sample_weight = NULL

self.samples = NULL
@@ -259,7 +262,7 @@ cdef class RegressionCriterion(Criterion):
def __reduce__(self):
return (type(self), (self.n_outputs, self.n_samples), self.__getstate__())

cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* treatment, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1:
"""Initialize the criterion.
Expand All @@ -269,6 +272,7 @@ cdef class RegressionCriterion(Criterion):
"""
# Initialize fields
self.y = y
self.treatment = treatment
self.sample_weight = sample_weight
self.samples = samples
self.start = start
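
With the `_criterion` changes above, every criterion now keeps a pointer to the treatment buffer next to `y` and `sample_weight`, so node-level statistics can be computed separately for treated and control samples. The NumPy sketch below illustrates the kind of per-node quantity this enables; it is not the project's actual Cython criterion, it uses a 1-D `y` for simplicity, and it assumes both treatment arms are present in the node.

```python
# Hedged NumPy analogue of what the treatment buffer makes possible inside a
# criterion: arm-specific statistics over the node's samples[start:end].
# Not the actual Cython implementation; 1-D y; both arms assumed non-empty.
import numpy as np

def node_effect(y, treatment, sample_weight, samples, start, end):
    idx = samples[start:end]                 # sample indices owned by this node
    w = sample_weight[idx]
    treated = treatment[idx] > 0.5
    mean_t = np.average(y[idx][treated], weights=w[treated])
    mean_c = np.average(y[idx][~treated], weights=w[~treated])
    return mean_t - mean_c                   # within-node effect estimate

samples = np.arange(8)
y = np.array([1.0, 2.0, 3.0, 4.0, 1.0, 1.0, 2.0, 2.0])
treatment = np.array([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
print(node_effect(y, treatment, np.ones(8), samples, 0, 8))   # 2.5 - 1.5 = 1.0
```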
3 changes: 2 additions & 1 deletion causalml/inference/tree/_tree/_splitter.pxd
@@ -63,6 +63,7 @@ cdef class Splitter:
cdef SIZE_t end # End position for the current node

cdef const DOUBLE_t[:, ::1] y
cdef DOUBLE_t* treatment
cdef DOUBLE_t* sample_weight

# The samples vector `samples` is maintained by the Splitter object such
@@ -83,7 +84,7 @@

# Methods
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight) except -1
DOUBLE_t* treatment, DOUBLE_t* sample_weight) except -1

cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil except -1
13 changes: 11 additions & 2 deletions causalml/inference/tree/_tree/_splitter.pyx
@@ -94,6 +94,7 @@ cdef class Splitter:
self.n_features = 0
self.feature_values = NULL

self.treatment = NULL
self.sample_weight = NULL

self.max_features = max_features
@@ -118,6 +119,7 @@
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter.
@@ -134,6 +136,9 @@
y : ndarray, dtype=DOUBLE_t
This is the vector of targets, or true labels, for the samples
treatment : DOUBLE_t*
The treatment assignments of the samples.
sample_weight : DOUBLE_t*
The weights of the samples, where higher weighted samples are fit
closer than lower weight samples. If not provided, all samples
@@ -180,6 +185,7 @@
self.y = y

self.sample_weight = sample_weight
self.treatment = treatment
return 0

cdef int node_reset(self, SIZE_t start, SIZE_t end,
@@ -203,6 +209,7 @@
self.end = end

self.criterion.init(self.y,
self.treatment,
self.sample_weight,
self.weighted_n_samples,
self.samples,
@@ -243,6 +250,7 @@ cdef class BaseDenseSplitter(Splitter):
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter
@@ -251,7 +259,7 @@
"""

# Call parent init
Splitter.init(self, X, y, sample_weight)
Splitter.init(self, X, y, treatment, sample_weight)

self.X = X
return 0
@@ -802,14 +810,15 @@ cdef class BaseSparseSplitter(Splitter):
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* treatment,
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
# Call parent init
Splitter.init(self, X, y, sample_weight)
Splitter.init(self, X, y, treatment, sample_weight)

if not isinstance(X, csc_matrix):
raise ValueError("X should be in csc format")
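
In `_splitter.pxd`/`_splitter.pyx` above, the splitter stores `treatment` in `init()` and forwards it to `Criterion.init()` inside `node_reset()`, exactly mirroring the existing handling of `sample_weight`. A hedged, pure-Python analogue of that plumbing is shown below; the real classes are Cython `cdef` classes operating on raw `DOUBLE_t*` buffers and carry additional arguments not shown here.

```python
# Hedged pure-Python analogue of the Splitter -> Criterion plumbing added here.
# PySplitter is a hypothetical stand-in; the real Splitter/Criterion are Cython
# cdef classes that work on raw DOUBLE_t* buffers.
class PySplitter:
    def __init__(self, criterion):
        self.criterion = criterion
        self.y = self.treatment = self.sample_weight = None

    def init(self, X, y, treatment, sample_weight):
        self.X = X
        self.y = y
        self.treatment = treatment            # stored next to sample_weight
        self.sample_weight = sample_weight

    def node_reset(self, samples, start, end, weighted_n_samples):
        # Forwarded to the criterion in the same call that already carried
        # y and sample_weight, matching the argument order in node_reset above.
        self.criterion.init(self.y, self.treatment, self.sample_weight,
                            weighted_n_samples, samples, start, end)
```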
2 changes: 2 additions & 0 deletions causalml/inference/tree/_tree/_tree.pxd
@@ -103,12 +103,14 @@ cdef class TreeBuilder:
Tree tree,
object X,
cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=*,
)

cdef _check_input(
self,
object X,
cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight,
)
19 changes: 14 additions & 5 deletions causalml/inference/tree/_tree/_tree.pyx
@@ -97,11 +97,13 @@ cdef class TreeBuilder:
"""Interface for different tree building strategies."""

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""
pass

cdef inline _check_input(self, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight):
"""Check input dtype, layout and format"""
if issparse(X):
@@ -122,13 +124,16 @@
if y.dtype != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

if treatment.dtype != DOUBLE or not treatment.flags.contiguous:
treatment = np.ascontiguousarray(treatment, dtype=DOUBLE)

if (sample_weight is not None and
(sample_weight.dtype != DOUBLE or
not sample_weight.flags.contiguous)):
sample_weight = np.asarray(sample_weight, dtype=DOUBLE,
order="C")

return X, y, sample_weight
return X, y, treatment, sample_weight

# Depth first builder ---------------------------------------------------------

@@ -146,12 +151,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -175,7 +182,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef double min_impurity_decrease = self.min_impurity_decrease

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef SIZE_t start
cdef SIZE_t end
@@ -328,12 +335,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, cnp.ndarray y,
cnp.ndarray treatment,
cnp.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -346,7 +355,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_split = self.min_samples_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef vector[FrontierRecord] frontier
cdef FrontierRecord record
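
In `_tree.pyx` above, `_check_input` now applies to `treatment` the same dtype/layout coercion it already applies to `y`, and `build()` then takes a raw pointer via `<DOUBLE_t*> treatment.data`. The standalone sketch below reproduces just the coercion step, assuming `DOUBLE` corresponds to `np.float64` as in the scikit-learn tree code this module is derived from.

```python
# Standalone sketch of the coercion _check_input now applies to `treatment`
# (mirrors the existing handling of y; assumes DOUBLE == np.float64).
import numpy as np

def coerce_treatment(treatment):
    treatment = np.asarray(treatment)
    if treatment.dtype != np.float64 or not treatment.flags.contiguous:
        treatment = np.ascontiguousarray(treatment, dtype=np.float64)
    return treatment

t = coerce_treatment([0, 1, 1, 0])
print(t.dtype, t.flags.c_contiguous)   # float64 True
```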
21 changes: 15 additions & 6 deletions causalml/inference/tree/causal/_builder.pyx
@@ -51,12 +51,14 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray treatment,
np.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -80,7 +82,7 @@ cdef class DepthFirstCausalTreeBuilder(TreeBuilder):
cdef double min_impurity_decrease = self.min_impurity_decrease

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef SIZE_t start
cdef SIZE_t end
@@ -239,13 +241,20 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder):
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None):
cpdef build(
self,
Tree tree,
object X,
np.ndarray y,
np.ndarray treatment,
np.ndarray sample_weight=None
):
"""Build a decision tree from the training set (X, y)."""

# check input
X, y, sample_weight = self._check_input(X, y, sample_weight)
X, y, treatment, sample_weight = self._check_input(X, y, treatment, sample_weight)

cdef DOUBLE_t* treatment_ptr = <DOUBLE_t*> treatment.data
cdef DOUBLE_t* sample_weight_ptr = NULL
if sample_weight is not None:
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
@@ -258,7 +267,7 @@ cdef class BestFirstCausalTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_split = self.min_samples_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr)
splitter.init(X, y, treatment_ptr, sample_weight_ptr)

cdef vector[FrontierRecord] frontier
cdef FrontierRecord record