Skip to content

Commit

Permalink
Fix non-list x_train/test argument in fit/predict
Browse files Browse the repository at this point in the history
  • Loading branch information
sergioburdisso committed Feb 26, 2020
1 parent 99a8a45 commit 5dbdc3a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,8 @@ def fit(self, x_train, y_train, n_grams=1, prep=True, leave_pbar=True):
cats = sorted(list(set(y_train)))
stime = time()

x_train, y_train = list(x_train), list(y_train)

x_train = [
"".join([
x_train[i]
Expand Down Expand Up @@ -2092,6 +2094,7 @@ def predict_proba(self, x_test, prep=True, leave_pbar=True):
if not self.__categories__:
raise EmptyModelError

x_test = list(x_test)
classify = self.classify
return [
classify(x, sort=False)
Expand Down Expand Up @@ -2151,7 +2154,7 @@ def predict(
stime = time()
Print.info("about to start classifying test documents", offset=1)
classify = self.classify_label if not multilabel else self.classify_multilabel

x_test = list(x_test)
y_pred = [
classify(doc, def_cat=def_cat, labels=labels, prep=prep)
for doc in tqdm(x_test, desc="Classification",
Expand Down

0 comments on commit 5dbdc3a

Please sign in to comment.