for #330, use tqdm.auto in trainer (#752)
* use tqdm.auto in trainer

This will import the ipywidgets version of tqdm if it is available, which works nicely in notebooks because the progress bar updates in place instead of filling up the log.

In a terminal it will fall back to the same old text-based tqdm (a minimal sketch of this selection behavior follows the commit message).

We might also want to consider passing the desired tqdm class in as an argument, since there may be edge cases where ipywidgets is installed but the frontend doesn't support it (e.g. VS Code?) or it isn't working. In those cases users will get a warning message, but they may want to configure the progress bar themselves (a hypothetical sketch of this idea follows the diffs below).

* use `from tqdm.auto` in eval loop

* indents
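
A minimal sketch (not part of the diff) of the selection behavior `tqdm.auto` provides: under Jupyter with ipywidgets installed the import resolves to the widget-based notebook bar, while in a plain terminal it falls back to the classic text bar.

```python
# Sketch of the behavior this commit relies on: one import, two renderings.
# In a Jupyter notebook with ipywidgets installed this draws a live widget;
# in a terminal it draws the usual single updating text line.
from tqdm.auto import tqdm

for batch in tqdm(range(100), desc='Validating', unit='batch'):
    pass  # training/eval work would go here
```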
wassname authored and williamFalcon committed Jan 26, 2020
1 parent 7deec2c commit deffbab
Showing 2 changed files with 10 additions and 10 deletions.
pytorch_lightning/trainer/evaluation_loop.py: 4 additions & 4 deletions
@@ -127,7 +127,7 @@
 from abc import ABC, abstractmethod

 import torch
-import tqdm
+from tqdm.auto import tqdm

 from pytorch_lightning.utilities.debugging import MisconfigurationException

@@ -293,9 +293,9 @@ def run_evaluation(self, test=False):
         # main progress bar will already be closed when testing so initial position is free
         position = 2 * self.process_position + (not test)
         desc = 'Testing' if test else 'Validating'
-        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
-                         disable=not self.show_progress_bar, dynamic_ncols=True,
-                         unit='batch', file=sys.stdout)
+        pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True,
+                    unit='batch', file=sys.stdout)
         setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

         # run evaluation
pytorch_lightning/trainer/trainer.py: 6 additions & 6 deletions
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import tqdm
+from tqdm.auto import tqdm
 from torch.optim.optimizer import Optimizer

 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
@@ -850,13 +850,13 @@ def run_pretrain_routine(self, model):
         ref_model.on_train_start()
         if not self.disable_validation and self.num_sanity_val_steps > 0:
             # init progress bars for validation sanity check
-            pbar = tqdm.tqdm(desc='Validation sanity check',
+            pbar = tqdm(desc='Validation sanity check',
                              total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                              leave=False, position=2 * self.process_position,
                              disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
             self.main_progress_bar = pbar
             # dummy validation progress bar
-            self.val_progress_bar = tqdm.tqdm(disable=True)
+            self.val_progress_bar = tqdm(disable=True)

             eval_results = self.evaluate(model, self.get_val_dataloaders(),
                                          self.num_sanity_val_steps, False)
@@ -870,9 +870,9 @@ def run_pretrain_routine(self, model):
             self.early_stop_callback.check_metrics(callback_metrics)

         # init progress bar
-        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
-                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
-                         file=sys.stdout)
+        pbar = tqdm(leave=True, position=2 * self.process_position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
+                    file=sys.stdout)
         self.main_progress_bar = pbar

         # clear cache before training
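
On the commit message's idea of letting callers pass in the tqdm class they want: a hypothetical sketch of what that could look like. The `progress_bar_cls` name and the wiring are illustrative assumptions, not part of this commit or of the actual Trainer API.

```python
from tqdm.auto import tqdm as auto_tqdm


class Trainer:
    def __init__(self, progress_bar_cls=auto_tqdm):
        # Hypothetical knob: default to tqdm.auto's choice, but let users
        # force an implementation when auto-detection misfires (e.g. a
        # frontend where ipywidgets is installed but unsupported).
        self.progress_bar_cls = progress_bar_cls

    def _make_progress_bar(self, **kwargs):
        return self.progress_bar_cls(dynamic_ncols=True, unit='batch', **kwargs)


# Usage: force the plain text bar in a misbehaving notebook frontend.
# from tqdm import tqdm as text_tqdm
# trainer = Trainer(progress_bar_cls=text_tqdm)
```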
