diff --git a/pyss3/cmd_line.py b/pyss3/cmd_line.py index 741fcd4..673de8f 100644 --- a/pyss3/cmd_line.py +++ b/pyss3/cmd_line.py @@ -727,7 +727,8 @@ def do_grid_search(self, args): Evaluation.grid_search( CLF, x_data, y_data, hparams["s"], hparams["l"], hparams["p"], hparams["a"], - k_fold, n_grams, def_cat, data_path, cache=cache + k_fold, n_grams, def_cat, data_path, + cache=cache, extended_pbar=True ) Print.warn( "Suggestion: use the command 'plot %s' to visualize the results" diff --git a/pyss3/util.py b/pyss3/util.py index e23ab84..0b67b40 100644 --- a/pyss3/util.py +++ b/pyss3/util.py @@ -698,7 +698,7 @@ def __evaluation_result__( @staticmethod def __grid_search_loop__( clf, x_test, y_test, ss, ll, pp, aa, k_fold, - i_fold, def_cat, tag, categories, cache=True, leave_pbar=True + i_fold, def_cat, tag, categories, cache=True, leave_pbar=True, extended_pbar=False ): """Grid search main loop.""" ss = [round_fix(s) for s in list_by_force(ss)] @@ -713,7 +713,8 @@ def __grid_search_loop__( ) progress_desc = tqdm( total=0, - bar_format='{desc}', leave=leave_pbar + bar_format='{desc}', leave=leave_pbar, + disable=not extended_pbar ) method = Evaluation.__kfold2method__(k_fold) @@ -1339,7 +1340,7 @@ def kfold_cross_validation( def grid_search( clf, x_data, y_data, s=None, l=None, p=None, a=None, k_fold=None, n_grams=None, def_cat=STR_MOST_PROBABLE, tag=None, - metric='accuracy', avg='macro', cache=True + metric='accuracy', avg='macro', cache=True, extended_pbar=False ): """ Perform a grid search using the provided hyperparameter values. @@ -1437,6 +1438,9 @@ def grid_search( to completely perform the evaluation ignoring cached values (default: True). :type cache: bool + :param extended_pbar: whether to show an extra status bar along with + the progress bar (default: False). + :type extended_pbar: bool :returns: a tuple of hyperparameter values (s, l, p, a) with the best values for the given metric :rtype: tuple @@ -1455,17 +1459,18 @@ def grid_search( Evaluation.__set_last_evaluation__(tag, method, def_cat) - s = s or clf.get_s() - l = l or clf.get_l() - p = p or clf.get_p() - a = a or clf.get_a() + s = s if s is not None else clf.get_s() + l = l if l is not None else clf.get_l() + p = p if p is not None else clf.get_p() + a = a if a is not None else clf.get_a() Print.show() if not k_fold: # if test x_test, y_test = x_data, [clf.get_category_index(y) for y in y_data] Evaluation.__grid_search_loop__( clf, x_test, y_test, s, l, p, a, 1, 0, - def_cat, tag, clf.get_categories(), cache + def_cat, tag, clf.get_categories(), cache, + extended_pbar=extended_pbar ) else: # if k-fold Print.verbosity_region_begin(VERBOSITY.NORMAL) @@ -1490,7 +1495,8 @@ def grid_search( Evaluation.__grid_search_loop__( clf_fold, x_test, y_test, s, l, p, a, k_fold, i_fold, - def_cat, tag, categories, cache, leave_pbar=False + def_cat, tag, categories, cache, + leave_pbar=False, extended_pbar=extended_pbar ) Evaluation.__cache_update__() @@ -1547,22 +1553,20 @@ def load_from_files(data_path, folder_label=True, as_single_doc=False): cat_info[cat] = len(docs) else: - folders = listdir(data_path) - for item in tqdm(folders, desc=" Categories", - leave=False, disable=Print.is_quiet()): - item_path = path.join(data_path, item) - if not path.isfile(item_path): - cat_info[item] = 0 - files = listdir(item_path) - for file in tqdm(files, desc=" Documents", - leave=False, disable=Print.is_quiet()): - file_path = path.join(item_path, file) + for icat, cat in enumerate(folders): + cat_path = path.join(data_path, cat) + if not path.isfile(cat_path): + cat_info[cat] = 0 + files = listdir(cat_path) + pbar_desc = "Loading documents for '%s' [%d/%d]" % (cat, icat + 1, len(folders)) + for file in tqdm(files, desc=pbar_desc, disable=Print.is_quiet()): + file_path = path.join(cat_path, file) if path.isfile(file_path): with open(file_path, "r", encoding=ENCODING) as ffile: x_data.append(ffile.read()) - y_data.append(item) - cat_info[item] += 1 + y_data.append(cat) + cat_info[cat] += 1 Print.info("%d categories found" % len(cat_info)) for cat in cat_info: