Skip to content

Commit

Permalink
Improve prog. bars compatiblity w/Jupyter Notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
sergioburdisso committed Feb 22, 2020
1 parent cebb8b2 commit 7848b3e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
3 changes: 2 additions & 1 deletion pyss3/cmd_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,8 @@ def do_grid_search(self, args):
Evaluation.grid_search(
CLF, x_data, y_data,
hparams["s"], hparams["l"], hparams["p"], hparams["a"],
k_fold, n_grams, def_cat, data_path, cache=cache
k_fold, n_grams, def_cat, data_path,
cache=cache, extended_pbar=True
)
Print.warn(
"Suggestion: use the command 'plot %s' to visualize the results"
Expand Down
46 changes: 25 additions & 21 deletions pyss3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def __evaluation_result__(
@staticmethod
def __grid_search_loop__(
clf, x_test, y_test, ss, ll, pp, aa, k_fold,
i_fold, def_cat, tag, categories, cache=True, leave_pbar=True
i_fold, def_cat, tag, categories, cache=True, leave_pbar=True, extended_pbar=False
):
"""Grid search main loop."""
ss = [round_fix(s) for s in list_by_force(ss)]
Expand All @@ -713,7 +713,8 @@ def __grid_search_loop__(
)
progress_desc = tqdm(
total=0,
bar_format='{desc}', leave=leave_pbar
bar_format='{desc}', leave=leave_pbar,
disable=not extended_pbar
)

method = Evaluation.__kfold2method__(k_fold)
Expand Down Expand Up @@ -1339,7 +1340,7 @@ def kfold_cross_validation(
def grid_search(
clf, x_data, y_data, s=None, l=None, p=None, a=None,
k_fold=None, n_grams=None, def_cat=STR_MOST_PROBABLE, tag=None,
metric='accuracy', avg='macro', cache=True
metric='accuracy', avg='macro', cache=True, extended_pbar=False
):
"""
Perform a grid search using the provided hyperparameter values.
Expand Down Expand Up @@ -1437,6 +1438,9 @@ def grid_search(
to completely perform the evaluation ignoring cached values
(default: True).
:type cache: bool
:param extended_pbar: whether to show an extra status bar along with
the progress bar (default: False).
:type extended_pbar: bool
:returns: a tuple of hyperparameter values (s, l, p, a) with the best
values for the given metric
:rtype: tuple
Expand All @@ -1455,17 +1459,18 @@ def grid_search(

Evaluation.__set_last_evaluation__(tag, method, def_cat)

s = s or clf.get_s()
l = l or clf.get_l()
p = p or clf.get_p()
a = a or clf.get_a()
s = s if s is not None else clf.get_s()
l = l if l is not None else clf.get_l()
p = p if p is not None else clf.get_p()
a = a if a is not None else clf.get_a()

Print.show()
if not k_fold: # if test
x_test, y_test = x_data, [clf.get_category_index(y) for y in y_data]
Evaluation.__grid_search_loop__(
clf, x_test, y_test, s, l, p, a, 1, 0,
def_cat, tag, clf.get_categories(), cache
def_cat, tag, clf.get_categories(), cache,
extended_pbar=extended_pbar
)
else: # if k-fold
Print.verbosity_region_begin(VERBOSITY.NORMAL)
Expand All @@ -1490,7 +1495,8 @@ def grid_search(

Evaluation.__grid_search_loop__(
clf_fold, x_test, y_test, s, l, p, a, k_fold, i_fold,
def_cat, tag, categories, cache, leave_pbar=False
def_cat, tag, categories, cache,
leave_pbar=False, extended_pbar=extended_pbar
)

Evaluation.__cache_update__()
Expand Down Expand Up @@ -1547,22 +1553,20 @@ def load_from_files(data_path, folder_label=True, as_single_doc=False):

cat_info[cat] = len(docs)
else:

folders = listdir(data_path)
for item in tqdm(folders, desc=" Categories",
leave=False, disable=Print.is_quiet()):
item_path = path.join(data_path, item)
if not path.isfile(item_path):
cat_info[item] = 0
files = listdir(item_path)
for file in tqdm(files, desc=" Documents",
leave=False, disable=Print.is_quiet()):
file_path = path.join(item_path, file)
for icat, cat in enumerate(folders):
cat_path = path.join(data_path, cat)
if not path.isfile(cat_path):
cat_info[cat] = 0
files = listdir(cat_path)
pbar_desc = "Loading documents for '%s' [%d/%d]" % (cat, icat + 1, len(folders))
for file in tqdm(files, desc=pbar_desc, disable=Print.is_quiet()):
file_path = path.join(cat_path, file)
if path.isfile(file_path):
with open(file_path, "r", encoding=ENCODING) as ffile:
x_data.append(ffile.read())
y_data.append(item)
cat_info[item] += 1
y_data.append(cat)
cat_info[cat] += 1

Print.info("%d categories found" % len(cat_info))
for cat in cat_info:
Expand Down

0 comments on commit 7848b3e

Please sign in to comment.