Commit dba8ba1

remove fig resizing from FIGS

csinva committed Mar 7, 2023
1 parent 82cdb46 commit dba8ba1

Showing 2 changed files with 99 additions and 65 deletions.
82 changes: 51 additions & 31 deletions imodels/experimental/figs_ensembles.py
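The substantive change in this file is the removal of the module-level plt.rcParams['figure.dpi'] = 300, which silently raised figure resolution for every plot in the importing process; the remaining edits are formatter reflow (see the comment thread below). A minimal sketch of how a caller can opt back in after this change, assuming matplotlib is already installed:

    import matplotlib.pyplot as plt

    # opt in per process instead of letting the library decide at import time
    plt.rcParams['figure.dpi'] = 300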
@@ -12,8 +12,6 @@
 
 from imodels.tree.viz_utils import extract_sklearn_tree_from_figs
 
-plt.rcParams['figure.dpi'] = 300
-
 
 class Node:
     def __init__(self, feature: int = None, threshold: int = None,
@@ -45,7 +43,8 @@ def __init__(self, feature: int = None, threshold: int = None,
     def update_values(self, X, y):
         self.value = y.mean()
         if self.threshold is not None:
-            right_indicator = np.apply_along_axis(lambda x: x[self.feature] > self.threshold, 1, X)
+            right_indicator = np.apply_along_axis(
+                lambda x: x[self.feature] > self.threshold, 1, X)
             X_right = X[right_indicator, :]
             X_left = X[~right_indicator, :]
             y_right = y[right_indicator]
@@ -61,9 +60,11 @@ def shrink(self, reg_param, cum_sum=0):
         if self.left is None:  # if leaf node, change prediction
             self.value = cum_sum
         else:
-            shrunk_diff = (self.left.value - self.value) / (1 + reg_param / self.n_samples)
+            shrunk_diff = (self.left.value - self.value) / \
+                (1 + reg_param / self.n_samples)
             self.left.shrink(reg_param, cum_sum + shrunk_diff)
-            shrunk_diff = (self.right.value - self.value) / (1 + reg_param / self.n_samples)
+            shrunk_diff = (self.right.value - self.value) / \
+                (1 + reg_param / self.n_samples)
             self.right.shrink(reg_param, cum_sum + shrunk_diff)
 
     def setattrs(self, **kwargs):
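For reference while reading the hunk above: shrink() damps each child-minus-parent difference by 1 / (1 + reg_param / n_samples) while accumulating the running sum down the tree. A standalone trace with invented numbers:

    # hypothetical values, just to trace the arithmetic in shrink()
    n_samples, reg_param = 100, 50
    parent_value, child_value = 0.4, 0.7
    shrunk_diff = (child_value - parent_value) / (1 + reg_param / n_samples)
    print(shrunk_diff)  # 0.3 / 1.5 = 0.2, so the leaf ends up predicting 0.4 + 0.2 = 0.6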
@@ -132,7 +133,7 @@ def _init_decision_function(self):
         """
         # used by sklearn GridSearchCV, BaggingClassifier
         if self.prediction_task == 'classification':
-            decision_function = lambda x: self.predict_proba(x)[:, 1]
+            def decision_function(x): return self.predict_proba(x)[:, 1]
         elif self.prediction_task == 'regression':
             decision_function = self.predict
 
@@ -166,7 +167,8 @@ def _construct_node_linear(self, X, y, idxs, tree_num=0, sample_weight=None):
                         feature=None, threshold=None,
                         impurity_reduction=-1, split_or_linear='split')  # leaf node that just returns its value
         else:
-            assert isinstance(best_linear_coef, float), 'coef should be a float'
+            assert isinstance(best_linear_coef,
+                              float), 'coef should be a float'
             return Node(idxs=idxs, value=best_linear_coef, tree_num=tree_num,
                         feature=best_feature, threshold=None,
                         impurity_reduction=impurity_reduction, split_or_linear='linear')
@@ -178,7 +180,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, max_features=None):
         RIGHT = 2
 
         # fit stump
-        stump = tree.DecisionTreeRegressor(max_depth=1, max_features=max_features)
+        stump = tree.DecisionTreeRegressor(
+            max_depth=1, max_features=max_features)
         if sample_weight is not None:
             sample_weight = sample_weight[idxs]
         stump.fit(X[idxs], y[idxs], sample_weight=sample_weight)
@@ -201,10 +204,10 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, max_features=None):
 
         # split node
         impurity_reduction = (
-                impurity[SPLIT] -
-                impurity[LEFT] * n_node_samples[LEFT] / n_node_samples[SPLIT] -
-                impurity[RIGHT] * n_node_samples[RIGHT] / n_node_samples[SPLIT]
-        ) * idxs.sum()
+            impurity[SPLIT] -
+            impurity[LEFT] * n_node_samples[LEFT] / n_node_samples[SPLIT] -
+            impurity[RIGHT] * n_node_samples[RIGHT] / n_node_samples[SPLIT]
+        ) * idxs.sum()
 
         node_split = Node(idxs=idxs, value=value[SPLIT], tree_num=tree_num,
                           feature=feature[SPLIT], threshold=threshold[SPLIT],
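The expression above is the standard weighted impurity decrease, rescaled by the number of samples the node covers. A toy check with invented, float-exact numbers:

    import numpy as np

    SPLIT, LEFT, RIGHT = 0, 1, 2
    impurity = [0.25, 0.125, 0.125]   # impurity at parent, left child, right child
    n_node_samples = [100, 50, 50]    # parent count = left + right
    idxs = np.ones(100, dtype=bool)   # this node covers all 100 samples
    impurity_reduction = (
        impurity[SPLIT] -
        impurity[LEFT] * n_node_samples[LEFT] / n_node_samples[SPLIT] -
        impurity[RIGHT] * n_node_samples[RIGHT] / n_node_samples[SPLIT]
    ) * idxs.sum()
    print(impurity_reduction)         # (0.25 - 0.0625 - 0.0625) * 100 = 12.5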
@@ -216,7 +219,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, max_features=None):
         idxs_left = idxs_split & idxs
         idxs_right = ~idxs_split & idxs
         node_left = Node(idxs=idxs_left, value=value[LEFT], tree_num=tree_num)
-        node_right = Node(idxs=idxs_right, value=value[RIGHT], tree_num=tree_num)
+        node_right = Node(
+            idxs=idxs_right, value=value[RIGHT], tree_num=tree_num)
         node_split.setattrs(left_temp=node_left, right_temp=node_right, )
         return node_split
 
@@ -231,7 +235,8 @@ def fit(self, X, y=None, feature_names=None, verbose=False, sample_weight=None):
         """
 
         if self.prediction_task == 'classification':
-            self.classes_, y = np.unique(y, return_inverse=True)  # deals with str inputs
+            self.classes_, y = np.unique(
+                y, return_inverse=True)  # deals with str inputs
         X, y = check_X_y(X, y)
         y = y.astype(float)
         if feature_names is not None:
@@ -252,7 +257,8 @@ def _update_tree_preds(n_iter):
                     if not tree_num_2_ == tree_num_:
                         y_residuals_per_tree[tree_num_] -= y_predictions_per_tree[tree_num_2_]
                 tree_.update_values(X, y_residuals_per_tree[tree_num_])
-                y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[tree_num_], X)
+                y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[
+                                                                       tree_num_], X)
 
         # set up initial potential_splits
         # everything in potential_splits either is_root (so it can be added directly to self.trees_)
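The residual bookkeeping in _update_tree_preds above is what makes the trees additive: each tree is refit on whatever the other trees fail to explain, so the per-tree predictions sum to the model's output. A toy trace with invented arrays:

    import numpy as np

    y = np.array([1.0, 0.0, 1.0])
    preds_tree1 = np.array([0.25, 0.0, 0.25])  # predictions of the other tree
    residual_tree0 = y - preds_tree1           # target tree 0 is refit on
    print(residual_tree0)                      # [0.75 0.   0.75]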
@@ -267,13 +273,15 @@ def _update_tree_preds(n_iter):
             potential_splits.append(node_init_linear)
         for node in potential_splits:
             node.setattrs(is_root=True)
-        potential_splits = sorted(potential_splits, key=lambda x: x.impurity_reduction)
+        potential_splits = sorted(
+            potential_splits, key=lambda x: x.impurity_reduction)
 
         # start the greedy fitting algorithm
         finished = False
         while len(potential_splits) > 0 and not finished:
             # print('potential_splits', [str(s) for s in potential_splits])
-            split_node = potential_splits.pop()  # get node with max impurity_reduction (since it's sorted)
+            # get node with max impurity_reduction (since it's sorted)
+            split_node = potential_splits.pop()
 
             # don't split on node
             if split_node.impurity_reduction < self.min_impurity_decrease:
@@ -304,16 +312,19 @@ def _update_tree_preds(n_iter):
             if split_node.split_or_linear == 'split':
                 # assign left_temp, right_temp to be proper children
                 # (basically adds them to tree in predict method)
-                split_node.setattrs(left=split_node.left_temp, right=split_node.right_temp)
+                split_node.setattrs(left=split_node.left_temp,
+                                    right=split_node.right_temp)
 
                 # add children to potential_splits
                 potential_splits.append(split_node.left)
                 potential_splits.append(split_node.right)
 
             # update predictions for altered tree
             for tree_num_ in range(len(self.trees_)):
-                y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[tree_num_], X)
-            y_predictions_per_tree[-1] = np.zeros(X.shape[0])  # dummy 0 preds for possible new trees
+                y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[
+                                                                       tree_num_], X)
+            # dummy 0 preds for possible new trees
+            y_predictions_per_tree[-1] = np.zeros(X.shape[0])
 
             # update residuals for each tree
             # -1 is key for potential new tree
Expand Down Expand Up @@ -352,7 +363,8 @@ def _update_tree_preds(n_iter):
)
elif potential_split.split_or_linear == 'linear':
assert potential_split.is_root, 'Currently, linear node only supported as root'
assert potential_split.idxs.sum() == X.shape[0], 'Currently, linear node only supported as root'
assert potential_split.idxs.sum(
) == X.shape[0], 'Currently, linear node only supported as root'
potential_split_updated = self._construct_node_linear(idxs=potential_split.idxs,
X=X,
y=y_target,
@@ -371,7 +383,8 @@ def _update_tree_preds(n_iter):
                 potential_splits_new.append(potential_split)
 
             # sort so largest impurity reduction comes last (should probs make this a heap later)
-            potential_splits = sorted(potential_splits_new, key=lambda x: x.impurity_reduction)
+            potential_splits = sorted(
+                potential_splits_new, key=lambda x: x.impurity_reduction)
             if verbose:
                 print(self)
             if self.max_rules is not None and self.complexity_ >= self.max_rules:
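The comment in this hunk already flags that the repeated sort could become a heap. A hypothetical sketch of that swap (not part of this commit), negating impurity_reduction because Python's heapq is a min-heap:

    import heapq

    # (neg_impurity_reduction, tie_breaker, payload); strings stand in for Node objects
    heap = [(-9.0, 0, 'root'), (-3.0, 1, 'leafA'), (-15.0, 2, 'leafB')]
    heapq.heapify(heap)
    neg_red, _, best = heapq.heappop(heap)  # O(log n) per pop instead of a full re-sort
    print(best, -neg_red)                   # leafB 15.0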
@@ -383,9 +396,11 @@ def _update_tree_preds(n_iter):
         # potentially fit linear model on the tree preds
         if self.posthoc_ridge:
             if self.prediction_task == 'regression':
-                self.weighted_model_ = RidgeCV(alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
+                self.weighted_model_ = RidgeCV(
+                    alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
             elif self.prediction_task == 'classification':
-                self.weighted_model_ = RidgeClassifierCV(alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
+                self.weighted_model_ = RidgeClassifierCV(
+                    alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
             X_feats = self._extract_tree_predictions(X)
             self.weighted_model_.fit(X_feats, y)
         return self
@@ -402,7 +417,8 @@ def _tree_to_str(self, root: Node, prefix=''):
                                 pprefix)
 
     def __str__(self):
-        s = '------------\n' + '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
+        s = '------------\n' + \
+            '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
         if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
             for i in range(len(self.feature_names_))[::-1]:
                 s = s.replace(f'X_{i}', self.feature_names_[i])
@@ -425,14 +441,16 @@ def predict_proba(self, X):
             return NotImplemented
         elif self.posthoc_ridge and self.weighted_model_:  # note: during fitting we don't use the weighted model
             X_feats = self._extract_tree_predictions(X)
-            d = self.weighted_model_.decision_function(X_feats)  # for 2 classes, this is (n_samples,)
+            d = self.weighted_model_.decision_function(
+                X_feats)  # for 2 classes, this is (n_samples,)
             probs = np.exp(d) / (1 + np.exp(d))
             return np.vstack((1 - probs, probs)).transpose()
         else:
             preds = np.zeros(X.shape[0])
             for tree in self.trees_:
                 preds += self._predict_tree(tree, X)
-            preds = np.clip(preds, a_min=0., a_max=1.)  # constrain to range of probabilities
+            # constrain to range of probabilities
+            preds = np.clip(preds, a_min=0., a_max=1.)
             return np.vstack((1 - preds, preds)).transpose()
 
     def _extract_tree_predictions(self, X):
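In the ridge branch of predict_proba above, decision-function margins are mapped through a logistic before being stacked into two-class probabilities. A standalone check of that transform with invented margins:

    import numpy as np

    d = np.array([-2.0, 0.0, 2.0])       # hypothetical decision_function margins
    probs = np.exp(d) / (1 + np.exp(d))  # logistic, as in the hunk above
    print(np.round(probs, 3))            # [0.119 0.5   0.881]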
Expand Down Expand Up @@ -473,7 +491,7 @@ def _predict_tree_single_point(root: Node, x):

def plot(self, cols=2, feature_names=None, filename=None, label="all",
impurity=False, tree_number=None, dpi=150, fig_size=None):
is_single_tree = len(self.trees_) < 2 or tree_number is not None
is_single_tree = len(self.trees_) < 2 or tree_number is not None
n_cols = int(cols)
n_rows = int(np.ceil(len(self.trees_) / n_cols))
# if is_single_tree:
@@ -486,7 +504,7 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
             fig.set_size_inches(fig_size, fig_size)
         criterion = "squared_error" if self.prediction_task == "regression" else "gini"
         n_classes = 1 if self.prediction_task == 'regression' else 2
-        ax_size = int(len(self.trees_))#n_cols * n_rows
+        ax_size = int(len(self.trees_))  # n_cols * n_rows
         for i in range(n_plots):
             r = i // n_cols
             c = i % n_cols
@@ -496,8 +514,10 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
             else:
                 ax = axs
             try:
-                dt = extract_sklearn_tree_from_figs(self, i if tree_number is None else tree_number, n_classes)
-                plot_tree(dt, ax=ax, feature_names=feature_names, label=label, impurity=impurity)
+                dt = extract_sklearn_tree_from_figs(
+                    self, i if tree_number is None else tree_number, n_classes)
+                plot_tree(dt, ax=ax, feature_names=feature_names,
+                          label=label, impurity=impurity)
             except IndexError:
                 ax.axis('off')
                 continue
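With the global dpi override removed, output resolution is now controlled per call through plot's own dpi argument. A hedged usage sketch; model stands for any fitted estimator from this module exposing the plot method above:

    # hypothetical call; parameter names are taken from the signature in this diff
    model.plot(cols=2, dpi=300, filename='figs_trees.png')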
(diff for the second changed file did not load)

2 comments on commit dba8ba1

@mepland (Collaborator) commented on dba8ba1 Mar 8, 2023

@csinva are you using a new code formatter now? LMK if I should set up the same.

@csinva (Owner, Author) commented on dba8ba1 Mar 8, 2023

Oops, no, sorry: I made this commit on a new machine that I hadn't set up the formatter on yet. No need to change what you're doing!
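For anyone trying to match the style: the formatter isn't named in the thread, but the wrapped continuation lines in this diff are consistent with autopep8's defaults. An unconfirmed guess at a matching setup:

    pip install autopep8
    autopep8 --in-place --recursive imodels/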
