Skip to content

Commit

Permalink
improve progress with warning & error warpper
Browse files Browse the repository at this point in the history
  • Loading branch information
aobo-y committed Mar 12, 2021
1 parent 4872fc1 commit a1a3098
Showing 1 changed file with 64 additions and 14 deletions.
78 changes: 64 additions & 14 deletions captum/_utils/progress.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,95 @@
#!/usr/bin/env python3

import sys
from typing import Iterable
import warnings
from typing import Iterable, Sized, TextIO, cast

try:
from tqdm import tqdm
except ImportError:
tqdm = None


def _simple_progress_out(iterable: Iterable, desc: str = None, total: int = None):
class DisableErrorIOWrapper(object):
def __init__(self, wrapped: TextIO):
"""
The wrapper around a TextIO object to ignore write errors like tqdm
https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
"""
self._wrapped = wrapped

def __getattr__(self, name):
return getattr(self._wrapped, name)

@staticmethod
def _wrapped_run(func, *args, **kwargs):
try:
return func(*args, **kwargs)
except OSError as e:
if e.errno != 5:
raise
except ValueError as e:
if "closed" not in str(e):
raise

def write(self, *args, **kwargs):
return self._wrapped_run(self._wrapped.write, *args, **kwargs)

def flush(self, *args, **kwargs):
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)


def _simple_progress_out(
iterable: Iterable, desc: str = None, total: int = None, file: TextIO = None
):
"""
Simple progress output used when tqdm is unavailable.
Same as tqdm, output to stderr channel
"""
cur = 0

if total is None and hasattr(iterable, "__len__"):
total = len(iterable)
total = len(cast(Sized, iterable))

desc = desc + ": " if desc else ""
progress_str = (
lambda cur: f"{desc}{100 * cur // total}% {cur}/{total}"
if total
else f"{desc}{'.' * cur}"
)

print("\r" + progress_str(cur), end="", file=sys.stderr)
def _progress_str(cur):
if total:
# e.g., progress: 60% 3/5
return f"{desc}{100 * cur // total}% {cur}/{total}"
else:
# e.g., progress: .....
return f"{desc}{'.' * cur}"

if not file:
file = sys.stderr
file = DisableErrorIOWrapper(file)

print("\r" + _progress_str(cur), end="", file=file)
for it in iterable:
yield it
cur += 1
print("\r" + progress_str(cur), end="", file=sys.stderr)
print("\r" + _progress_str(cur), end="", file=file)

print(file=sys.stderr) # end with new line
print(file=file) # end with new line


def progress(
iterable: Iterable, desc: str = None, total: int = None, use_tqdm=True, **kwargs
iterable: Iterable,
desc: str = None,
total: int = None,
use_tqdm=True,
file: TextIO = None,
**kwargs,
):
# Try to use tqdm is possible. Fall back to simple progress print
if tqdm and use_tqdm:
return tqdm(iterable, desc=desc, total=total, **kwargs)
return tqdm(iterable, desc=desc, total=total, file=file, **kwargs)
else:
return _simple_progress_out(iterable, desc=desc, total=total)
if not tqdm and use_tqdm:
warnings.warn(
"Tried to show progress with tqdm "
"but tqdm is not installed. "
"Fall back to simply print out the progress."
)
return _simple_progress_out(iterable, desc=desc, total=total, file=file)

0 comments on commit a1a3098

Please sign in to comment.