Skip to content

Commit

Permalink
updated missing changes in #776 per @paullo's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee committed May 26, 2024
1 parent 3bbce11 commit ef3a41c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
8 changes: 7 additions & 1 deletion causalml/inference/tree/_tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ def __init__(

@abstractmethod
def fit(
self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated"
self,
X,
treatment,
y,
sample_weight=None,
check_input=True,
X_idx_sorted="deprecated",
):
pass

Expand Down
2 changes: 1 addition & 1 deletion causalml/inference/tree/causal/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def _support_missing_values(self, X) -> bool:
def fit(
self,
X,
y,
treatment,
y,
sample_weight=None,
check_input=True,
X_idx_sorted="deprecated",
Expand Down
2 changes: 1 addition & 1 deletion causalml/inference/tree/causal/causaltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def fit(
treatment (np.ndarray): treatment vector
y (np.ndarray): outcome vector
sample_weight (np.ndarray): sample_weight
check_input (bool, optional), default=False
check_input (bool, optional): default=False
Returns:
self
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_causal_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_fit_predict(
def test_predict(self, generate_regression_data):
y, X, treatment, tau, b, e = generate_regression_data(mode=2)
ctree = self.prepare_model()
ctree.fit(X=X, y=y, treatment=treatment)
ctree.fit(X=X, treatment=treatment, y=y)
y_pred = ctree.predict(X[:1, :])
y_pred_with_outcomes = ctree.predict(X[:1, :], with_outcomes=True)
assert y_pred.shape == (1,)
Expand Down

0 comments on commit ef3a41c

Please sign in to comment.