diff --git a/pyss3/__init__.py b/pyss3/__init__.py index f95620d..97ec790 100644 --- a/pyss3/__init__.py +++ b/pyss3/__init__.py @@ -2044,6 +2044,8 @@ def fit(self, x_train, y_train, n_grams=1, prep=True, leave_pbar=True): cats = sorted(list(set(y_train))) stime = time() + x_train, y_train = list(x_train), list(y_train) + x_train = [ "".join([ x_train[i] @@ -2092,6 +2094,7 @@ def predict_proba(self, x_test, prep=True, leave_pbar=True): if not self.__categories__: raise EmptyModelError + x_test = list(x_test) classify = self.classify return [ classify(x, sort=False) @@ -2151,7 +2154,7 @@ def predict( stime = time() Print.info("about to start classifying test documents", offset=1) classify = self.classify_label if not multilabel else self.classify_multilabel - + x_test = list(x_test) y_pred = [ classify(doc, def_cat=def_cat, labels=labels, prep=prep) for doc in tqdm(x_test, desc="Classification",