Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tf catvars from tlapusan #265

Merged
merged 34 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
39ee126
262#Fix a unit test
tlapusan Feb 14, 2023
c5dc1c9
262#Adapt get_split_node_heights for categorical values
tlapusan Feb 14, 2023
e6e5c4e
262# Fix an issue for choosing where to draw the wedge
tlapusan Feb 15, 2023
c8bcf86
262# Display categorical split nodes for categorical trees
tlapusan Feb 15, 2023
60ac36e
262# Handle categorical features for regression tree (except when hig…
tlapusan Feb 16, 2023
ff0de3a
262# Display the correct wedge in case of highlighted path
tlapusan Feb 17, 2023
6b746a7
Turn off splits for catvars in 1D classifier feature space.
parrt Feb 17, 2023
b7bb267
Turn off splits for catvars in 1D classifier feature space.
parrt Feb 17, 2023
9f7a1d6
working on catvar feature space
parrt Feb 17, 2023
c75d771
raise exception for catvars in feature space plots.
parrt Feb 17, 2023
ff2a171
bump version to 2.2.0
parrt Feb 17, 2023
44afb3c
remove unnecessary code.
parrt Feb 17, 2023
2741d21
Fix test for non-numeric numpy arrays
parrt Feb 17, 2023
d7044a6
watch out for <OOD> symbol used by set operations in tf-df
parrt Feb 18, 2023
4d218d8
tweak
parrt Feb 18, 2023
752f39c
rm debugging print
parrt Feb 18, 2023
4e32a78
more general avoidance of <OOD> catvar value from tf; clean up whites…
parrt Feb 18, 2023
a234b6f
add test py file regr tree with catvar node
parrt Feb 18, 2023
6bb0806
For regressor + catvar split node, split might be on an unordered set…
parrt Feb 18, 2023
18e8adb
snapshot. trying to get wedges for regr classsplit
parrt Feb 18, 2023
09675ae
rename arg
parrt Feb 18, 2023
1c7e4b6
rename arg; cleanup
parrt Feb 19, 2023
5bd6858
play with example
parrt Feb 19, 2023
097107d
Fix the plot issue when a class label is not in a tree node. (#266)
tlapusan Feb 20, 2023
81f6553
Revert "For regressor + catvar split node, split might be on an unord…
parrt Feb 20, 2023
ffbeac6
don't need triangles anymore for regr + cat split
parrt Feb 20, 2023
cee2c20
Merge branch 'tf-catvars-from-tlapusan' of github.com:parrt/dtreeviz …
parrt Feb 20, 2023
3b7f05a
add X_train sanity check for NaN
parrt Feb 20, 2023
9040674
rm print
parrt Feb 20, 2023
414d2d2
If we're highlighting a node in classifier with cat split, x will be …
parrt Feb 20, 2023
aa8202f
play with example
parrt Feb 20, 2023
42d7b26
update tests and spark notebook
parrt Feb 20, 2023
3ff7210
Merge branch 'tf-catvars-from-tlapusan' of github.com:parrt/dtreeviz …
parrt Feb 20, 2023
b06a476
fix order of `class_names=["perish", "survive"]` in examples
parrt Feb 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .idea/animl.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 23 additions & 6 deletions dtreeviz/models/shadow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,21 +268,27 @@ def is_categorical_split(self, id) -> bool:
def get_split_node_heights(self, X_train, y_train, nbins) -> Mapping[int, int]:
class_values = np.unique(y_train)
node_heights = {}
# print(f"Goal {nbins} bins")
for node in self.internal:
# print(node.feature_name(), node.id)
# print(f"node feature {node.feature_name()}, id {node.id}")
X_feature = X_train[:, node.feature()]
overall_feature_range = (np.min(X_feature), np.max(X_feature))
# print(f"range {overall_feature_range}")
if node.is_categorical_split():
overall_feature_range = (0, len(np.unique(X_train[:, node.feature()])) - 1)
else:
overall_feature_range = (np.min(X_feature), np.max(X_feature))

bins = np.linspace(overall_feature_range[0],
overall_feature_range[1], nbins + 1)
# print(f"\tlen(bins)={len(bins):2d} bins={bins}")
X, y = X_feature[node.samples()], y_train[node.samples()]

# in case there is a categorical split node, we can convert the values to numbers because we need them
# only for getting the distribution values
if node.is_categorical_split():
X = pd.Series(X).astype("category").cat.codes

X_hist = [X[y == cl] for cl in class_values]
height_of_bins = np.zeros(nbins)
for i, _ in enumerate(class_values):
hist, foo = np.histogram(X_hist[i], bins=bins, range=overall_feature_range)
# print(f"class {cl}: goal_n={len(bins):2d} n={len(hist):2d} {hist}")
height_of_bins += hist
node_heights[node.id] = np.max(height_of_bins)
# print(f"\tmax={np.max(height_of_bins):2.0f}, heights={list(height_of_bins)}, {len(height_of_bins)} bins")
Expand Down Expand Up @@ -413,6 +419,17 @@ def _get_y_data(y_train):

@staticmethod
def get_shadow_tree(tree_model, X_train, y_train, feature_names, target_name, class_names=None, tree_index=None):
"""Get an internal representation of the tree obtained from a specific library"""
# Sanity check
if isinstance(X_train, pd.DataFrame):
nancols = X_train.columns[X_train.isnull().any().values].tolist()
if len(nancols)>0:
raise ValueError(f"dtreeviz does not support NaN (see column(s) {', '.join(nancols)})")
elif isinstance(X_train, np.ndarray):
nancols = np.where(pd.isnull(X_train).any(axis=0))[0].astype(str).tolist()
if len(nancols)>0:
raise ValueError(f"dtreeviz does not support NaN (see column index(es) {', '.join(nancols)})")

"""
To check to which library the tree_model belongs we are using string checks instead of isinstance()
because we don't want all the libraries to be installed as mandatory, except sklearn.
Expand Down
111 changes: 82 additions & 29 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from colour import Color, rgb2hex
from sklearn import tree

from dtreeviz.colors import adjust_colors
from dtreeviz.interpretation import explain_prediction_plain_english, explain_prediction_sklearn_default
from dtreeviz.models.shadow_decision_tree import ShadowDecTree
from dtreeviz.models.shadow_decision_tree import ShadowDecTreeNode
from dtreeviz.utils import myround, DTreeVizRender, add_classifier_legend, _format_axes, _draw_wedge, _set_wedge_ticks, tessellate
from dtreeviz.utils import myround, DTreeVizRender, add_classifier_legend, _format_axes, _draw_wedge, \
_set_wedge_ticks, tessellate, is_numeric

# How many bins should we have based upon number of classes
NUM_BINS = [
Expand Down Expand Up @@ -1100,14 +1102,26 @@ def _class_split_viz(node: ShadowDecTreeNode,
# keep the bar widths as uniform as possible for all node visualisations
nbins = nbins if nbins > feature_unique_size else feature_unique_size + 1

overall_feature_range = (np.min(X_train[:, node.feature()]), np.max(X_train[:, node.feature()]))
bins = np.linspace(start=overall_feature_range[0], stop=overall_feature_range[1], num=nbins, endpoint=True)
# only needed for categorical features of str type; int categorical features work fine as numerical ones
if node.is_categorical_split() and type(X_feature[0]) == str:
# TODO think if the len() should be from all training[feature] data vs only data from this specific node ?
overall_feature_range = (0, len(np.unique(X_feature)) - 1)
else:
overall_feature_range = (np.min(X_train[:, node.feature()]), np.max(X_train[:, node.feature()]))

bins = np.linspace(start=overall_feature_range[0], stop=overall_feature_range[1], num=nbins, endpoint=True)
_format_axes(ax, feature_name, None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)

class_names = node.shadow_tree.class_names
class_values = node.shadow_tree.classes()
X_hist = [X_node_feature[y_train == cl] for cl in class_values]

# for multiclass examples, there could be scenarios where a node won't contain all the class value labels, which will
# generate a matplotlib exception. To solve this, we need to keep only the class values which belong to the node and
# their corresponding colors.
X_colors = [colors[cl] for i, cl in enumerate(class_values) if len(X_hist[i]) > 0]
X_hist = [hist for hist in X_hist if len(hist) > 0]

if histtype == 'strip':
ax.yaxis.set_visible(False)
ax.spines['left'].set_visible(False)
Expand All @@ -1122,8 +1136,6 @@ def _class_split_viz(node: ShadowDecTreeNode,
ax.scatter(bucket, y_noise, alpha=alpha, marker='o', s=dot_w, c=colors[i],
edgecolors=colors['edge'], lw=.3)
else:
X_colors = [colors[cl] for cl in class_values]

hist, bins, barcontainers = ax.hist(X_hist,
color=X_colors,
align='mid',
Expand All @@ -1143,9 +1155,24 @@ def _class_split_viz(node: ShadowDecTreeNode,

ax.set_xlim(*overall_feature_range_wide)

wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_class=True, h=h, height_range=height_range, bins=bins)
if highlight_node:
_ = _draw_wedge(ax, x=X[node.feature()], node=node, color=colors['highlight'], is_class=True, h=h, height_range=height_range, bins=bins)
if node.is_categorical_split() and type(X_feature[0]) == str:
# run draw to refresh the figure to get the xticklabels
plt.draw()
node_split = list(map(str, node.split()))
# get the label text and its position from the figure
label_index = dict([(label.get_text(), label.get_position()[0]) for label in ax.get_xticklabels()])
# get tick positions, ignoring "out of dictionary" symbol added by tensorflow trees for "unknown symbol"
wedge_ticks_position = [label_index[split] for split in node_split if split in label_index]
wedge_ticks = _draw_wedge(ax, x=wedge_ticks_position, node=node, color=colors['wedge'], is_classifier=True, h=h,
height_range=height_range, bins=bins)
if highlight_node:
highlight_value = [label_index[X[node.feature()]]]
_ = _draw_wedge(ax, x=highlight_value, node=node, color=colors['highlight'], is_classifier=True, h=h,
height_range=height_range, bins=bins)
else:
wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=True, h=h, height_range=height_range, bins=bins)
if highlight_node:
_ = _draw_wedge(ax, x=X[node.feature()], node=node, color=colors['highlight'], is_classifier=True, h=h, height_range=height_range, bins=bins)

_set_wedge_ticks(ax, ax_ticks=list(overall_feature_range), wedge_ticks=wedge_ticks)

Expand Down Expand Up @@ -1201,21 +1228,24 @@ def _regr_split_viz(node: ShadowDecTreeNode,
fig, ax = plt.subplots(1, 1, figsize=figsize)

feature_name = node.feature_name()

_format_axes(ax, feature_name, target_name if node == node.shadow_tree.root else None, colors, fontsize=label_fontsize, fontname=fontname, ticks_fontsize=ticks_fontsize, grid=False, pad_for_wedge=True)
ax.set_ylim(y_range)

# Get X, y data for all samples associated with this node.
X_feature = X_train[:, node.feature()]
X_feature, y_train = X_feature[node.samples()], y_train[node.samples()]

overall_feature_range = (np.min(X_train[:, node.feature()]), np.max(X_train[:, node.feature()]))
# only needed for categorical features of str type; int categorical features work fine as numerical ones
if node.is_categorical_split() and type(X_feature[0]) == str:
# TODO think if the len() should be from all training[feature] data vs only data from this specific node ?
overall_feature_range = (0, len(np.unique(X_feature)) - 1)
else:
overall_feature_range = (np.min(X_train[:, node.feature()]), np.max(X_train[:, node.feature()]))

ax.set_xlim(*overall_feature_range)
xmin, xmax = overall_feature_range
xr = xmax - xmin

if not node.is_categorical_split():

ax.scatter(X_feature, y_train, s=5, c=colors['scatter_marker'], alpha=colors['scatter_marker_alpha'], lw=.3)
left, right = node.split_samples()
left = y_train[left]
Expand All @@ -1228,10 +1258,10 @@ def _regr_split_viz(node: ShadowDecTreeNode,
ax.plot([split, overall_feature_range[1]], [np.mean(right), np.mean(right)], '--', color=colors['split_line'],
linewidth=1)

wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_class=False)
wedge_ticks = _draw_wedge(ax, x=node.split(), node=node, color=colors['wedge'], is_classifier=False)

if highlight_node:
_ = _draw_wedge(ax, x=X[node.feature()], node=node, color=colors['highlight'], is_class=False)
_ = _draw_wedge(ax, x=X[node.feature()], node=node, color=colors['highlight'], is_classifier=False)

_set_wedge_ticks(ax, ax_ticks=list(overall_feature_range), wedge_ticks=wedge_ticks)

Expand All @@ -1252,11 +1282,19 @@ def _regr_split_viz(node: ShadowDecTreeNode,
color=colors["categorical_split_right"],
linewidth=1)

if highlight_node:
_ = _draw_wedge(ax, x=X[node.feature()], node=node, color=colors['highlight'], is_class=False)
# no wedge ticks for categorical split, just the x_ticks in case the categorical value is not a string
# if it's a string, then the xticks label will be handled automatically by ax.scatter plot
if type(X_feature[0]) is not str:
ax.set_xticks(np.unique(np.concatenate((X_feature, np.asarray(overall_feature_range)))))

# no wedge ticks for categorical split
ax.set_xticks(np.unique(np.concatenate((X_feature, np.asarray(overall_feature_range)))))
if highlight_node:
highlight_value = X[node.feature()]
if type(X_feature[0]) is str:
plt.draw()
# get the label text and its position from the figure
label_index = dict([(label.get_text(), label.get_position()[0]) for label in ax.get_xticklabels()])
highlight_value = label_index[X[node.feature()]]
_ = _draw_wedge(ax, x=highlight_value, node=node, color=colors['highlight'], is_classifier=False)

if filename is not None:
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
Expand Down Expand Up @@ -1295,7 +1333,6 @@ def _regr_leaf_viz(node: ShadowDecTreeNode,
xycoords='axes fraction', textcoords='offset points',
fontsize=label_fontsize, fontname=fontname, color=colors['axis_label'])


mu = .5
sigma = .08
X = np.random.normal(mu, sigma, size=len(y))
Expand Down Expand Up @@ -1475,6 +1512,9 @@ def _ctreeviz_univar(shadow_tree,
color_map = {v: color_values[i] for i, v in enumerate(class_values)}
X_colors = [color_map[cl] for cl in class_values]

# if np.numeric(X_train[:,featidx])
if not is_numeric(X_train[:,featidx]):
raise ValueError(f"ctree_feature_space only supports numeric feature spaces")

_format_axes(ax, shadow_tree.feature_names[featidx], 'Count' if gtype=='barstacked' else None,
colors, fontsize, fontname, ticks_fontsize=ticks_fontsize, grid=False)
Expand Down Expand Up @@ -1572,15 +1612,8 @@ def _ctreeviz_bivar(shadow_tree, fontsize, ticks_fontsize, fontname, show,
color_values = colors['classes'][n_classes]
color_map = {v: color_values[i] for i, v in enumerate(class_values)}

if 'splits' in show:
for node, bbox in tessellation:
x = bbox[0]
y = bbox[1]
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
rect = patches.Rectangle((x, y), w, h, angle=0, linewidth=.3, alpha=colors['tessellation_alpha'],
edgecolor=colors['rect_edge'], facecolor=color_map[node.prediction()])
ax.add_patch(rect)
if not is_numeric(X_train[:,featidx[0]]) or not is_numeric(X_train[:,featidx[1]]):
raise ValueError(f"ctree_feature_space only supports numeric feature spaces")

dot_w = 25
X_hist = [X_train[y_train == cl] for cl in class_values]
Expand All @@ -1591,6 +1624,17 @@ def _ctreeviz_bivar(shadow_tree, fontsize, ticks_fontsize, fontname, show,
_format_axes(ax, shadow_tree.feature_names[featidx[0]], shadow_tree.feature_names[featidx[1]],
colors, fontsize, fontname, ticks_fontsize=ticks_fontsize, grid=False)

if 'splits' in show:
plt.draw()
for node, bbox in tessellation:
x = bbox[0]
y = bbox[1]
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
rect = patches.Rectangle((x, y), w, h, angle=0, linewidth=.3, alpha=colors['tessellation_alpha'],
edgecolor=colors['rect_edge'], facecolor=color_map[node.prediction()])
ax.add_patch(rect)

if 'legend' in show:
add_classifier_legend(ax, shadow_tree.class_names, class_values, color_map, shadow_tree.target_name, colors,
fontname=fontname)
Expand All @@ -1613,6 +1657,9 @@ def _rtreeviz_univar(shadow_tree, fontsize, ticks_fontsize, fontname, show,
if X_train is None or y_train is None:
raise ValueError(f"X_train and y_train must not be none")

if not is_numeric(X_train[:,featidx]):
raise ValueError(f"rtree_feature_space only supports numeric feature spaces")

if ax is None:
if figsize:
fig, ax = plt.subplots(figsize=figsize)
Expand Down Expand Up @@ -1691,6 +1738,9 @@ def _rtreeviz_bivar_heatmap(shadow_tree, fontsize, ticks_fontsize, fontname,
n_colors_in_map)]
featidx = [shadow_tree.feature_names.index(f) for f in features]

if not is_numeric(X_train[:,featidx[0]]) or not is_numeric(X_train[:,featidx[1]]):
raise ValueError(f"rtree_feature_space only supports numeric feature spaces")

tessellation = tessellate(shadow_tree.root, X_train, featidx)

for node, bbox in tessellation:
Expand Down Expand Up @@ -1755,8 +1805,11 @@ def plane(node, bbox, color_spectrum):
y_colors = [color_spectrum[y_to_color_index(y)] for y in y_train]

featidx = [shadow_tree.feature_names.index(f) for f in features]
x, y, z = X_train[:, featidx[0]], X_train[:, featidx[1]], y_train

if not is_numeric(X_train[:,featidx[0]]) or not is_numeric(X_train[:,featidx[1]]):
raise ValueError(f"rtree_feature_space3D only supports numeric feature spaces")

x, y, z = X_train[:, featidx[0]], X_train[:, featidx[1]], y_train
tessellation = tessellate(shadow_tree.root, X_train, featidx)

for node, bbox in tessellation:
Expand Down
22 changes: 17 additions & 5 deletions dtreeviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ def _format_axes(ax, xlabel, ylabel, colors, fontsize, fontname, ticks_fontsize=
ax.grid(visible=grid)


def _draw_wedge(ax, x, node, color, is_class, h=None, height_range=None, bins=None):

def _draw_wedge(ax, x, node, color, is_classifier, h=None, height_range=None, bins=None):
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
x_range = xmax - xmin
Expand All @@ -388,7 +387,7 @@ def _draw_tria(tip_x, tip_y, tri_width, tri_height):
ax.add_patch(t)
wedge_ticks.append(tip_x)

if is_class:
if is_classifier:
hr = h / (height_range[1] - height_range[0])
tri_height = y_range * .15 * 1 / hr # convert to graph coordinates (ugh)
tip_y = -0.1 * y_range * .15 * 1 / hr
Expand All @@ -397,7 +396,10 @@ def _draw_tria(tip_x, tip_y, tri_width, tri_height):
_draw_tria(x, tip_y, tri_width, tri_height)
else:
# classification: categorical split, draw multiple wedges
for split_value in node.split():
# If we're highlighting a node, x will be one value not multiple.
if np.size(x)==1:
x = [x] # normalize to a list even if one value
for split_value in x:
# to display the wedge exactly in the middle of the vertical bar
for bin_index in range(len(bins) - 1):
if bins[bin_index] <= split_value <= bins[bin_index + 1]:
Expand All @@ -408,7 +410,6 @@ def _draw_tria(tip_x, tip_y, tri_width, tri_height):
# regression
tri_height = y_range * .1
_draw_tria(x, ymin, tri_width, tri_height)

return wedge_ticks


Expand Down Expand Up @@ -443,6 +444,8 @@ def tessellate(root, X_train, featidx):
"""
Walk tree and return list of tuples containing a leaf node and bounding box list
of(x1, y1, x2, y2) coordinates.

Does not work for catvars!
"""
bboxes = [] # filled in by walk()
f1_values = X_train[:, featidx[0]]
Expand Down Expand Up @@ -474,3 +477,12 @@ def walk(t, bbox, nsplits):
walk(root, overall_bbox, 0)

return bboxes


def is_numeric(A:np.ndarray) -> bool:
    """Return True if every element of ``A`` can be cast to ``float``.

    The check is done by attempting ``A.astype(float)`` and discarding the
    converted copy; ``A`` itself is not modified.

    :param A: array whose values should be tested (e.g. one feature column).
    :return: True when the cast succeeds, False when it fails.
    """
    try:
        A.astype(float)
    except (ValueError, TypeError):
        # ValueError: a string such as "male" that does not parse as a number.
        # TypeError: an object-dtype array holding a non-numeric Python object.
        return False
    return True
2 changes: 1 addition & 1 deletion dtreeviz/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
__version__ = '2.1.4'
__version__ = '2.2.0'
2 changes: 1 addition & 1 deletion notebooks/dtreeviz_lightgbm_visualisations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
"viz_model = dtreeviz.model(lgbm_model, tree_index=1,\n",
" X_train=dataset[features], y_train=dataset[target],\n",
" feature_names=features,\n",
" target_name=target, class_names=[\"survive\", \"perish\"])"
" target_name=target, class_names=[\"perish\", \"survive\"])"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/dtreeviz_sklearn_pipeline_visualisations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@
"viz_model = dtreeviz.model(tree_classifier,\n",
" X_train=X_train, y_train=y_train,\n",
" feature_names=features_model,\n",
" target_name=target, class_names=[\"survive\", \"perish\"])"
" target_name=target, class_names=[\"perish\", \"survive\"])"
]
},
{
Expand Down
Loading