Skip to content

Commit 72d6d83

Browse files
ArjunR2404facebook-github-bot
authored andcommitted
Starter Task 1: Ensuring TQDM Progress Bar Always Enabled (pytorch#1575)
Summary: My task was to remove the redundant progress bar code now that tqdm is a dependency. The problem was that since tqdm is now a dependency, we don't need to make checks for whether or not tqdm is being used, and thus we can remove the simple progress print code in the progress function (as well as the import error try-except code at the very top of the file). I also removed the SimpleProgress class definition, as well as the use_tqdm parameter in the progress function. Since we are no longer using the SimpleProgress object anywhere, I removed it from wherever it was used in the codebase (feature_ablation.py). In the test_progress.py, I removed the tests corresponding to SimpleProgress. Reviewed By: cyrjano Differential Revision: D75814260
1 parent f8762d6 commit 72d6d83

File tree

3 files changed

+18
-204
lines changed

3 files changed

+18
-204
lines changed

captum/_utils/progress.py

Lines changed: 12 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
# pyre-strict
44

5-
import sys
65
import typing
7-
import warnings
8-
from time import time
96
from types import TracebackType
107
from typing import (
118
Any,
@@ -15,17 +12,13 @@
1512
Iterator,
1613
Literal,
1714
Optional,
18-
Sized,
1915
TextIO,
2016
Type,
2117
TypeVar,
2218
Union,
2319
)
2420

25-
try:
26-
from tqdm.auto import tqdm
27-
except ImportError:
28-
tqdm = None
21+
from tqdm.auto import tqdm
2922

3023
T = TypeVar("T")
3124
IterableType = TypeVar("IterableType")
@@ -105,144 +98,41 @@ def close(self) -> None:
10598
pass
10699

107100

108-
class SimpleProgress(Iterable[IterableType]):
109-
def __init__(
110-
self,
111-
iterable: Optional[Iterable[IterableType]] = None,
112-
desc: Optional[str] = None,
113-
total: Optional[int] = None,
114-
file: Optional[TextIO] = None,
115-
mininterval: float = 0.5,
116-
) -> None:
117-
"""
118-
Simple progress output used when tqdm is unavailable.
119-
Same as tqdm, output to stderr channel.
120-
If you want to do nested Progressbars with simple progress
121-
the parent progress bar should be used as a context
122-
(i.e. with statement) and the nested progress bar should be
123-
created inside this context.
124-
"""
125-
self.cur = 0
126-
self.iterable = iterable
127-
self.total = total
128-
if total is None and hasattr(iterable, "__len__"):
129-
self.total = len(cast(Sized, iterable))
130-
131-
self.desc = desc
132-
133-
file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr)
134-
self.file: DisableErrorIOWrapper = file_wrapper
135-
136-
self.mininterval = mininterval
137-
self.last_print_t = 0.0
138-
self.closed = False
139-
self._is_parent = False
140-
141-
def __enter__(self) -> "SimpleProgress[IterableType]":
142-
self._is_parent = True
143-
self._refresh()
144-
return self
145-
146-
def __exit__(
147-
self,
148-
exc_type: Union[Type[BaseException], None],
149-
exc_value: Union[BaseException, None],
150-
exc_traceback: Union[TracebackType, None],
151-
) -> Literal[False]:
152-
self.close()
153-
return False
154-
155-
def __iter__(self) -> Iterator[IterableType]:
156-
if self.closed or not self.iterable:
157-
return
158-
self._refresh()
159-
for it in cast(Iterable[IterableType], self.iterable):
160-
yield it
161-
self.update()
162-
self.close()
163-
164-
def _refresh(self) -> None:
165-
progress_str = self.desc + ": " if self.desc else ""
166-
if self.total:
167-
# e.g., progress: 60% 3/5
168-
progress_str += (
169-
f"{100 * self.cur // cast(int, self.total)}%"
170-
f" {self.cur}/{cast(int, self.total)}"
171-
)
172-
else:
173-
# e.g., progress: .....
174-
progress_str += "." * self.cur
175-
end = "\n" if self._is_parent else ""
176-
print("\r" + progress_str, end=end, file=self.file)
177-
178-
def update(self, amount: int = 1) -> None:
179-
if self.closed:
180-
return
181-
self.cur += amount
182-
183-
cur_t = time()
184-
if cur_t - self.last_print_t >= self.mininterval:
185-
self._refresh()
186-
self.last_print_t = cur_t
187-
188-
def close(self) -> None:
189-
if not self.closed and not self._is_parent:
190-
self._refresh()
191-
print(file=self.file) # end with new line
192-
self.closed = True
193-
194-
195101
@typing.overload
196102
def progress(
197103
iterable: None = None,
198104
desc: Optional[str] = None,
199105
total: Optional[int] = None,
200-
use_tqdm: bool = True,
201106
file: Optional[TextIO] = None,
202107
mininterval: float = 0.5,
203108
**kwargs: object,
204-
) -> Union[SimpleProgress[None], tqdm]: ...
109+
) -> tqdm: ...
205110

206111

207112
@typing.overload
208113
def progress(
209114
iterable: Iterable[IterableType],
210115
desc: Optional[str] = None,
211116
total: Optional[int] = None,
212-
use_tqdm: bool = True,
213117
file: Optional[TextIO] = None,
214118
mininterval: float = 0.5,
215119
**kwargs: object,
216-
) -> Union[SimpleProgress[IterableType], tqdm]: ...
120+
) -> tqdm: ...
217121

218122

219123
def progress(
220124
iterable: Optional[Iterable[IterableType]] = None,
221125
desc: Optional[str] = None,
222126
total: Optional[int] = None,
223-
use_tqdm: bool = True,
224127
file: Optional[TextIO] = None,
225128
mininterval: float = 0.5,
226129
**kwargs: object,
227-
) -> Union[SimpleProgress[IterableType], tqdm]:
228-
# Try to use tqdm is possible. Fall back to simple progress print
229-
if tqdm and use_tqdm:
230-
return tqdm(
231-
iterable,
232-
desc=desc,
233-
total=total,
234-
file=file,
235-
mininterval=mininterval,
236-
**kwargs,
237-
)
238-
else:
239-
if not tqdm and use_tqdm:
240-
warnings.warn(
241-
"Tried to show progress with tqdm "
242-
"but tqdm is not installed. "
243-
"Fall back to simply print out the progress.",
244-
stacklevel=1,
245-
)
246-
return SimpleProgress(
247-
iterable, desc=desc, total=total, file=file, mininterval=mininterval
248-
)
130+
) -> tqdm:
131+
return tqdm(
132+
iterable,
133+
desc=desc,
134+
total=total,
135+
file=file,
136+
mininterval=mininterval,
137+
**kwargs,
138+
)

captum/attr/_core/feature_ablation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
_run_forward,
3131
)
3232
from captum._utils.exceptions import FeatureAblationFutureError
33-
from captum._utils.progress import progress, SimpleProgress
33+
from captum._utils.progress import progress
3434
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
3535
from captum.attr._utils.attribution import PerturbationAttribution
3636
from captum.attr._utils.common import (
@@ -41,10 +41,7 @@
4141
from torch import dtype, Tensor
4242
from torch.futures import collect_all, Future
4343

44-
try:
45-
from tqdm.auto import tqdm
46-
except ImportError:
47-
tqdm = None
44+
from tqdm.auto import tqdm
4845

4946
IterableType = TypeVar("IterableType")
5047

@@ -418,7 +415,7 @@ def _attribute_with_independent_feature_masks(
418415
formatted_feature_mask: Tuple[Tensor, ...],
419416
num_examples: int,
420417
perturbations_per_eval: int,
421-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
418+
attr_progress: Optional[tqdm],
422419
initial_eval: Tensor,
423420
flattened_initial_eval: Tensor,
424421
n_outputs: int,
@@ -500,7 +497,7 @@ def _attribute_with_cross_tensor_feature_masks(
500497
target: TargetType,
501498
baselines: BaselineType,
502499
formatted_feature_mask: Tuple[Tensor, ...],
503-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
500+
attr_progress: Optional[tqdm],
504501
flattened_initial_eval: Tensor,
505502
initial_eval: Tensor,
506503
n_outputs: int,
@@ -831,7 +828,7 @@ def _attribute_with_independent_feature_masks_future(
831828
baselines: BaselineType,
832829
formatted_feature_mask: Tuple[Tensor, ...],
833830
perturbations_per_eval: int,
834-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
831+
attr_progress: Optional[tqdm],
835832
processed_initial_eval_fut: Future[
836833
Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
837834
],
@@ -942,7 +939,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
942939
target: TargetType,
943940
baselines: BaselineType,
944941
formatted_feature_mask: Tuple[Tensor, ...],
945-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
942+
attr_progress: Optional[tqdm],
946943
processed_initial_eval_fut: Future[
947944
Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
948945
],

tests/utils/test_progress.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,6 @@ def test_nested_progress_tqdm(self, mock_stderr) -> None:
4141
for item in parent_data:
4242
self.assertIn(f"test progress {item}:", output)
4343

44-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
45-
def test_nested_simple_progress(self, mock_stderr) -> None:
46-
parent_data = ["x", "y", "z"]
47-
test_data = [1, 2, 3]
48-
with progress(
49-
parent_data, desc="parent progress", use_tqdm=False, mininterval=0.0
50-
) as parent:
51-
for item in parent:
52-
for _ in progress(
53-
test_data, desc=f"test progress {item}", use_tqdm=False
54-
):
55-
pass
56-
57-
output = mock_stderr.getvalue()
58-
self.assertEqual(
59-
output.count("parent progress:"), 5, "5 'parent' progress bar expected"
60-
)
61-
for item in parent_data:
62-
self.assertIn(f"test progress {item}:", output)
63-
6444
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
6545
def test_progress_tqdm(self, mock_stderr) -> None:
6646
try:
@@ -73,56 +53,3 @@ def test_progress_tqdm(self, mock_stderr) -> None:
7353
progressed = progress(test_data, desc="test progress")
7454
assert list(progressed) == test_data
7555
assert "test progress: " in mock_stderr.getvalue()
76-
77-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
78-
def test_simple_progress(self, mock_stderr) -> None:
79-
test_data = [1, 3, 5]
80-
desc = "test progress"
81-
82-
progressed = progress(test_data, desc=desc, use_tqdm=False)
83-
84-
assert list(progressed) == test_data
85-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3")
86-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n")
87-
88-
# progress iterable without len but explicitly specify total
89-
def gen():
90-
for n in test_data:
91-
yield n
92-
93-
mock_stderr.seek(0)
94-
mock_stderr.truncate(0)
95-
96-
progressed = progress(gen(), desc=desc, total=len(test_data), use_tqdm=False)
97-
98-
assert list(progressed) == test_data
99-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3")
100-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n")
101-
102-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
103-
def test_simple_progress_without_total(self, mock_stderr) -> None:
104-
test_data = [1, 3, 5]
105-
desc = "test progress"
106-
107-
def gen():
108-
for n in test_data:
109-
yield n
110-
111-
progressed = progress(gen(), desc=desc, use_tqdm=False)
112-
113-
assert list(progressed) == test_data
114-
assert mock_stderr.getvalue().startswith(f"\r{desc}: ")
115-
assert mock_stderr.getvalue().endswith(f"\r{desc}: ...\n")
116-
117-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
118-
def test_simple_progress_update_manually(self, mock_stderr) -> None:
119-
desc = "test progress"
120-
121-
p = progress(total=5, desc=desc, use_tqdm=False)
122-
p.update(0)
123-
p.update(2)
124-
p.update(2)
125-
p.update(1)
126-
p.close()
127-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/5")
128-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 5/5\n")

0 commit comments

Comments
 (0)