
KeyError when using decision_boundaries function #238

Closed
0ptimista opened this issue Jan 6, 2023 · 11 comments

@0ptimista

0ptimista commented Jan 6, 2023

I have a trained DecisionTreeClassifier model with 2 features. And it is good when using dtreeviz.model() to observe the model.

[Screenshot: dtreeviz.model() rendering of the trained tree]

But when I try decision_boundaries(), it throws a KeyError and draws only the decision boundaries without the data points. I want those points:

Output exceeds the size limit; traceback truncated:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[34], line 1
----> 1 decision_boundaries(
      2     dt_wx,
      3     X_train,
      4     y_train,
      5     feature_names=list(X_train.columns),
      6     target_name='error_rate',
      7     class_names=["OK", "Problem"],
      8 )

File /opt/homebrew/Caskroom/miniconda/base/envs/data-science-py311/lib/python3.11/site-packages/dtreeviz/classifiers.py:113, in decision_boundaries(model, X, y, ntiles, tile_fraction, binary_threshold, show, feature_names, target_name, class_names, markers, boundary_marker, boundary_markersize, fontsize, fontname, dot_w, yshift, sigma, colors, ranges, figsize, ax)
     97     decision_boundaries_univar(model=model, x=X, y=y,
     98                                ntiles=ntiles,
     99                                binary_threshold=binary_threshold,
   (...)
    110                                figsize=figsize,
    111                                ax=ax)
    112 elif len(X.shape) == 2 and X.shape[1] == 2:
--> 113     decision_boundaries_bivar(model=model, X=X, y=y,
    114                               ntiles=ntiles, tile_fraction=tile_fraction,
    115                               binary_threshold=binary_threshold,
    116                               show=show,
    117                               feature_names=feature_names, target_name=target_name,
...
    207                lw=.5)
    208     # Show misclassified markers (can't have alpha per marker so do in 2 calls)
    209     bad_x = x_[class_X_pred[i] != class_values[i],:]


@parrt
Owner

parrt commented Jan 6, 2023

Can you send data + small program? I can debug.

@0ptimista
Author

Is it OK if I send those to your email address on GitHub? Or is there a better way?

@parrt
Owner

parrt commented Jan 8, 2023

You can probably attach them here if they're not too big, but my email is OK as well.

@0ptimista
Author

I tried this on Jupyter Notebook.

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from dtreeviz import decision_boundaries
import dtreeviz

data = pd.read_csv('sample.csv')
X=data.drop('stat',axis=1)
y=data['stat']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

dt_wx = DecisionTreeClassifier(max_depth=6)
dt_wx.fit(X.values, y.values)

viz = dtreeviz.model(
    dt_wx,
    X_train,
    y_train,
    feature_names=list(X_train.columns),
    target_name='stat',
    class_names=["OK", "Problem"],
)
viz.view(scale=1)

decision_boundaries(
    dt_wx, X_train, y_train,
    ntiles=40,
    tile_fraction=1,
    feature_names=list(X_train.columns),
    target_name='stat',
    class_names=["OK", "Problem"],
)

sample.csv

Thanks for helping Professor!

@tlapusan
Collaborator

tlapusan commented Jan 8, 2023

@parrt just a hint: I did a little debugging on the code, and the error is generated because:

  1. the y class values are [1 2]
  2. color_map is {1: '#FEFEBB', 2: '#a1dab4'}
  3. at line 202 it tries to get c=color_map[i], where i = 0. The dict color_map doesn't contain the key 0.
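The mismatch described above can be reproduced in isolation. This is a hypothetical minimal reconstruction, not the actual dtreeviz source: a color map keyed by the real class values gets indexed with positional indices 0..n-1.

```python
# Hypothetical sketch of the failure: color_map is keyed by the actual
# class values ([1, 2] here), but the plotting loop looks up positional
# indices starting at 0.
y = [1, 2, 1, 2, 2]                          # labels start at 1, not 0
class_values = sorted(set(y))                # [1, 2]
colors = ['#FEFEBB', '#a1dab4']
color_map = dict(zip(class_values, colors))  # {1: '#FEFEBB', 2: '#a1dab4'}

try:
    for i in range(len(class_values)):       # i = 0 first...
        color = color_map[i]                 # KeyError: 0, as in the traceback
except KeyError as e:
    print("KeyError:", e)
```

With zero-based labels the two indexing schemes coincide, which is why most datasets never hit this.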

@parrt
Owner

parrt commented Jan 14, 2023

I think we make an assumption that class values all start from zero, right?

@tlapusan
Collaborator

@parrt I guess yes; I am not very familiar with that part of the implementation.

@parrt
Owner

parrt commented Jan 15, 2023

OK @0ptimista, the issue is that class labels have to start from zero, but the labels in this case are [1, 2]. It must be very common to keep everything indexed from zero, so for now I'm going to simply add code to indicate this is an error.

@parrt
Owner

parrt commented Jan 15, 2023

You can probably do something like y=data['stat']-1
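The `- 1` shift works for labels [1, 2] specifically. A more general remapping (a sketch not taken from this thread) is to build a dict from each distinct label to its sorted rank; `pandas.factorize(y, sort=True)` does the same thing in one call.

```python
# Generic remap of arbitrary class labels to 0..n-1 before calling
# dtreeviz; the [1, 2] labels match this issue's data.
labels = [1, 2, 1, 2, 2]
mapping = {v: i for i, v in enumerate(sorted(set(labels)))}  # {1: 0, 2: 1}
zero_based = [mapping[v] for v in labels]                    # [0, 1, 0, 1, 1]
```

Keeping `mapping` around also lets you translate predictions back to the original label values afterwards.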

@parrt
Owner

parrt commented Jan 15, 2023

I am adding functionality to emit an error:

Traceback (most recent call last):
  File "/Users/parrt/github/dtreeviz/t2.py", line 24, in <module>
    viz.view(scale=1)
  File "/Users/parrt/github/dtreeviz/dtreeviz/trees.py", line 478, in view
    raise ValueError("Target label values (for now) must be 0..n-1 for n labels")
ValueError: Target label values (for now) must be 0..n-1 for n labels
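The check behind that error can be sketched as follows (a guess at the shape of the validation, not the actual dtreeviz code): the sorted distinct labels must be exactly 0..n-1.

```python
# Sketch of a zero-based-label check matching the error message above.
def check_zero_based(y):
    vals = sorted(set(y))
    if vals != list(range(len(vals))):
        raise ValueError("Target label values (for now) must be 0..n-1 for n labels")

check_zero_based([0, 1, 0, 1])   # passes silently
# check_zero_based([1, 2, 1])    # would raise ValueError, like this issue's data
```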

parrt added a commit that referenced this issue Jan 15, 2023
… n classes.

Signed-off-by: Terence Parr <parrt@antlr.org>
@parrt parrt added the clean up label Jan 15, 2023
@parrt parrt added this to the 2.1 milestone Jan 15, 2023
@parrt parrt closed this as completed Jan 15, 2023
@0ptimista
Author

@parrt @tlapusan I set my classes to start from 0 as suggested, and now I can see those points.

The new ValueError above is a really good hint. Again, thanks for helping!
