Does pyGAM support multi-class classification? #196

Open
zhangxz1123 opened this issue Aug 27, 2018 · 9 comments · May be fixed by #213

Comments

@zhangxz1123

No description provided.

@dswah
Owner

dswah commented Aug 27, 2018

@zhangxz1123
wow you read my mind! i was just thinking about this yesterday!

no. currently pygam does not support multiclass classification.

however you can quickly extend pygam to do so.

if you have M classes, then you can train M one-vs-all models. when you predict, keep the label for the model that outputs the highest probability.

this is a little less efficient than fitting a single softmax model, but the predictions should be about the same.
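
The one-vs-all recipe above can be sketched as follows. This is a minimal sketch, not part of pyGAM: `fit_one_vs_rest`, `predict_one_vs_rest`, and the `make_model` factory are hypothetical names, and the code assumes each binary model has a `fit(X, y)` that returns the fitted model and a `predict_proba(X)` that returns a 1-D array with the probability of the positive class.

```python
import numpy as np


def fit_one_vs_rest(make_model, X, y):
    """Fit one binary model per class; returns (classes, models)."""
    y = np.asarray(y)
    classes = np.unique(y)
    # Each model sees 1 for "its" class and 0 for every other class.
    # Assumption: fit() returns the fitted model.
    models = [make_model().fit(X, (y == c).astype(int)) for c in classes]
    return classes, models


def predict_one_vs_rest(classes, models, X):
    """Return, per row, the label whose model reports the highest probability."""
    # Assumption: predict_proba returns one P(positive class) per row,
    # so stacking the M outputs gives shape (n_samples, M).
    scores = np.column_stack([np.ravel(m.predict_proba(X)) for m in models])
    return classes[np.argmax(scores, axis=1)]
```

With pyGAM, `make_model` could be something like `lambda: LogisticGAM()`. The M probabilities come from independently fitted models, so they are not guaranteed to sum to 1 across classes; only their ranking is used here.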

@jolespin

I was thinking about this as well!

Would this work?
http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html

from pygam import LogisticGAM
from sklearn.multiclass import OneVsRestClassifier

model__multiclass = OneVsRestClassifier(LogisticGAM())

model__multiclass
# OneVsRestClassifier(estimator=LogisticGAM(callbacks=['deviance', 'diffs', 'accuracy'],
#    fit_intercept=True, max_iter=100, terms='auto', tol=0.0001,
#    verbose=False),
#           n_jobs=1)

@dswah
Owner

dswah commented Sep 25, 2018

@jolespin
oh wow, i forgot about that class!

right, does that work? i haven't tried it out.

@jolespin

@dswah I couldn't get it to work:

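from pygam import LogisticGAM
from sklearn import model_selection
from sklearn.datasets import load_iris
from sklearn.multiclass import OneVsRestClassifier  # lives in sklearn.multiclass, not sklearn.ensemble

# assumed source of X_iris / y_iris: the scikit-learn iris dataset
X_iris, y_iris = load_iris(return_X_y=True)

base_estimator = LogisticGAM(n_splines=20)
ensemble = OneVsRestClassifier(base_estimator, n_jobs=1)
ensemble.fit(X_iris, y_iris)
model_selection.cross_val_score(ensemble, X=X_iris, y=y_iris, cv=10)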
# ---------------------------------------------------------------------------
# AttributeError                            Traceback (most recent call last)
# ~/anaconda/envs/python3/lib/python3.6/site-packages/sklearn/multiclass.py in _predict_binary(estimator, X)
#      94     try:
# ---> 95         score = np.ravel(estimator.decision_function(X))
#      96     except (AttributeError, NotImplementedError):

# ~/anaconda/envs/python3/lib/python3.6/site-packages/pygam/terms.py in __getattr__(self, name)
#     978 
# --> 979         return self._super_get(name)
#     980 

# ~/anaconda/envs/python3/lib/python3.6/site-packages/pygam/terms.py in _super_get(self, name)
#     899     def _super_get(self, name):
# --> 900         return super(MetaTermMixin, self).__getattribute__(name)
#     901 

# AttributeError: 'LogisticGAM' object has no attribute 'decision_function'

@dswah
Owner

dswah commented Sep 25, 2018

blah, well that's annoying.

@jolespin thank you for trying!

@dswah
Owner

dswah commented Oct 3, 2018

this PR (#213) appears to do the trick

@jolespin

jolespin commented Oct 3, 2018

Should this be the default for the multiclass classification case?

@dswah
Owner

dswah commented Oct 3, 2018

@jolespin do you mean that pygam should import scikit-learn?

i think pygam should NOT import sklearn because it is such a big library with its various dependencies....

but then perhaps it doesn't make sense to add the decision_function method, since it assumes that users will use sklearn?
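
For illustration, one way a `decision_function` could be supplied without pyGAM importing scikit-learn is a small mixin that converts probabilities to log-odds. This is a hedged sketch only: `DecisionFunctionMixin` is a hypothetical helper, it is not taken from PR #213, and it assumes `predict_proba(X)` returns the probability of the positive class for each row.

```python
import numpy as np


class DecisionFunctionMixin:
    """Adds a decision_function built from predict_proba (hypothetical helper)."""

    def decision_function(self, X):
        # Clip so probabilities of exactly 0 or 1 do not produce infinities.
        eps = 1e-12
        p = np.clip(np.ravel(self.predict_proba(X)), eps, 1 - eps)
        # Log-odds: positive when P(positive class) > 0.5, negative otherwise.
        return np.log(p / (1 - p))
```

A user-side subclass such as `class ScoredLogisticGAM(DecisionFunctionMixin, LogisticGAM): pass` (the name is made up here) would keep the scikit-learn convention out of pyGAM itself. Whether that subclass then works inside `OneVsRestClassifier` is untested in this sketch.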

@kra268

kra268 commented Nov 10, 2023

Since this is still open: is it a good approach to split the target manually into n one-vs-all targets and fit n GAMs? I understand this can be tedious for a large number of classes, but with, for example, 3 classes [1, 2, 3], you would fit 3 GAMs, where GAM i is a binary classifier for class i (labeled 1) vs. the rest (labeled 0). This assumes that splitting your target is simple enough. You could do this manually instead of using OneVsRestClassifier from sklearn.
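
The manual target split described above could look like this. It is a small sketch covering only the splitting step; `one_vs_rest_targets` is a hypothetical helper name, and each resulting 0/1 vector would then be passed to its own binary GAM.

```python
import numpy as np


def one_vs_rest_targets(y):
    """Map each class label to a 0/1 target vector (1 = that class, 0 = rest)."""
    y = np.asarray(y)
    return {c: (y == c).astype(int) for c in np.unique(y).tolist()}


targets = one_vs_rest_targets([1, 2, 3, 1])
# targets[1] -> array([1, 0, 0, 1])
# targets[2] -> array([0, 1, 0, 0])
# targets[3] -> array([0, 0, 1, 0])
```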


4 participants