Skip to content

Commit 156458e

Browse files
ArjunR2404facebook-github-bot
authored andcommitted
Starter Task 1: Ensuring TQDM Progress Bar Always Enabled
Summary: My task was to remove the redundant progress bar code now that tqdm is a dependency. The problem was that since tqdm is now a dependency, we don't need to make checks for whether or not tqdm is being used, and thus we can remove the simple progress print code in the progress function (as well as the import error try-except code at the very top of the file). I also removed the SimpleProgress class definition, as well as the use_tqdm parameter in the progress function. Since we are no longer using the SimpleProgress object anywhere, I removed it from wherever it was used in the codebase (feature_ablation.py, this will be submitted as another diff related to all changes made to feature_ablation.py). In the test_progress.py, I removed the tests corresponding to SimpleProgress. Reviewed By: cyrjano Differential Revision: D75814260
1 parent 7dac7f4 commit 156458e

File tree

2 files changed

+12
-191
lines changed

2 files changed

+12
-191
lines changed

captum/_utils/progress.py

Lines changed: 12 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
Union,
2323
)
2424

25-
try:
26-
from tqdm.auto import tqdm
27-
except ImportError:
28-
tqdm = None
25+
from tqdm.auto import tqdm
2926

3027
T = TypeVar("T")
3128
IterableType = TypeVar("IterableType")
@@ -105,144 +102,41 @@ def close(self) -> None:
105102
pass
106103

107104

108-
class SimpleProgress(Iterable[IterableType]):
109-
def __init__(
110-
self,
111-
iterable: Optional[Iterable[IterableType]] = None,
112-
desc: Optional[str] = None,
113-
total: Optional[int] = None,
114-
file: Optional[TextIO] = None,
115-
mininterval: float = 0.5,
116-
) -> None:
117-
"""
118-
Simple progress output used when tqdm is unavailable.
119-
Same as tqdm, output to stderr channel.
120-
If you want to do nested Progressbars with simple progress
121-
the parent progress bar should be used as a context
122-
(i.e. with statement) and the nested progress bar should be
123-
created inside this context.
124-
"""
125-
self.cur = 0
126-
self.iterable = iterable
127-
self.total = total
128-
if total is None and hasattr(iterable, "__len__"):
129-
self.total = len(cast(Sized, iterable))
130-
131-
self.desc = desc
132-
133-
file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr)
134-
self.file: DisableErrorIOWrapper = file_wrapper
135-
136-
self.mininterval = mininterval
137-
self.last_print_t = 0.0
138-
self.closed = False
139-
self._is_parent = False
140-
141-
def __enter__(self) -> "SimpleProgress[IterableType]":
142-
self._is_parent = True
143-
self._refresh()
144-
return self
145-
146-
def __exit__(
147-
self,
148-
exc_type: Union[Type[BaseException], None],
149-
exc_value: Union[BaseException, None],
150-
exc_traceback: Union[TracebackType, None],
151-
) -> Literal[False]:
152-
self.close()
153-
return False
154-
155-
def __iter__(self) -> Iterator[IterableType]:
156-
if self.closed or not self.iterable:
157-
return
158-
self._refresh()
159-
for it in cast(Iterable[IterableType], self.iterable):
160-
yield it
161-
self.update()
162-
self.close()
163-
164-
def _refresh(self) -> None:
165-
progress_str = self.desc + ": " if self.desc else ""
166-
if self.total:
167-
# e.g., progress: 60% 3/5
168-
progress_str += (
169-
f"{100 * self.cur // cast(int, self.total)}%"
170-
f" {self.cur}/{cast(int, self.total)}"
171-
)
172-
else:
173-
# e.g., progress: .....
174-
progress_str += "." * self.cur
175-
end = "\n" if self._is_parent else ""
176-
print("\r" + progress_str, end=end, file=self.file)
177-
178-
def update(self, amount: int = 1) -> None:
179-
if self.closed:
180-
return
181-
self.cur += amount
182-
183-
cur_t = time()
184-
if cur_t - self.last_print_t >= self.mininterval:
185-
self._refresh()
186-
self.last_print_t = cur_t
187-
188-
def close(self) -> None:
189-
if not self.closed and not self._is_parent:
190-
self._refresh()
191-
print(file=self.file) # end with new line
192-
self.closed = True
193-
194-
195105
@typing.overload
196106
def progress(
197107
iterable: None = None,
198108
desc: Optional[str] = None,
199109
total: Optional[int] = None,
200-
use_tqdm: bool = True,
201110
file: Optional[TextIO] = None,
202111
mininterval: float = 0.5,
203112
**kwargs: object,
204-
) -> Union[SimpleProgress[None], tqdm]: ...
113+
) -> tqdm: ...
205114

206115

207116
@typing.overload
208117
def progress(
209118
iterable: Iterable[IterableType],
210119
desc: Optional[str] = None,
211120
total: Optional[int] = None,
212-
use_tqdm: bool = True,
213121
file: Optional[TextIO] = None,
214122
mininterval: float = 0.5,
215123
**kwargs: object,
216-
) -> Union[SimpleProgress[IterableType], tqdm]: ...
124+
) -> tqdm: ...
217125

218126

219127
def progress(
220128
iterable: Optional[Iterable[IterableType]] = None,
221129
desc: Optional[str] = None,
222130
total: Optional[int] = None,
223-
use_tqdm: bool = True,
224131
file: Optional[TextIO] = None,
225132
mininterval: float = 0.5,
226133
**kwargs: object,
227-
) -> Union[SimpleProgress[IterableType], tqdm]:
228-
# Try to use tqdm is possible. Fall back to simple progress print
229-
if tqdm and use_tqdm:
230-
return tqdm(
231-
iterable,
232-
desc=desc,
233-
total=total,
234-
file=file,
235-
mininterval=mininterval,
236-
**kwargs,
237-
)
238-
else:
239-
if not tqdm and use_tqdm:
240-
warnings.warn(
241-
"Tried to show progress with tqdm "
242-
"but tqdm is not installed. "
243-
"Fall back to simply print out the progress.",
244-
stacklevel=1,
245-
)
246-
return SimpleProgress(
247-
iterable, desc=desc, total=total, file=file, mininterval=mininterval
248-
)
134+
) -> tqdm:
135+
return tqdm(
136+
iterable,
137+
desc=desc,
138+
total=total,
139+
file=file,
140+
mininterval=mininterval,
141+
**kwargs,
142+
)

tests/utils/test_progress.py

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,6 @@ def test_nested_progress_tqdm(self, mock_stderr) -> None:
4141
for item in parent_data:
4242
self.assertIn(f"test progress {item}:", output)
4343

44-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
45-
def test_nested_simple_progress(self, mock_stderr) -> None:
46-
parent_data = ["x", "y", "z"]
47-
test_data = [1, 2, 3]
48-
with progress(
49-
parent_data, desc="parent progress", use_tqdm=False, mininterval=0.0
50-
) as parent:
51-
for item in parent:
52-
for _ in progress(
53-
test_data, desc=f"test progress {item}", use_tqdm=False
54-
):
55-
pass
56-
57-
output = mock_stderr.getvalue()
58-
self.assertEqual(
59-
output.count("parent progress:"), 5, "5 'parent' progress bar expected"
60-
)
61-
for item in parent_data:
62-
self.assertIn(f"test progress {item}:", output)
63-
6444
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
6545
def test_progress_tqdm(self, mock_stderr) -> None:
6646
try:
@@ -73,56 +53,3 @@ def test_progress_tqdm(self, mock_stderr) -> None:
7353
progressed = progress(test_data, desc="test progress")
7454
assert list(progressed) == test_data
7555
assert "test progress: " in mock_stderr.getvalue()
76-
77-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
78-
def test_simple_progress(self, mock_stderr) -> None:
79-
test_data = [1, 3, 5]
80-
desc = "test progress"
81-
82-
progressed = progress(test_data, desc=desc, use_tqdm=False)
83-
84-
assert list(progressed) == test_data
85-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3")
86-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n")
87-
88-
# progress iterable without len but explicitly specify total
89-
def gen():
90-
for n in test_data:
91-
yield n
92-
93-
mock_stderr.seek(0)
94-
mock_stderr.truncate(0)
95-
96-
progressed = progress(gen(), desc=desc, total=len(test_data), use_tqdm=False)
97-
98-
assert list(progressed) == test_data
99-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3")
100-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n")
101-
102-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
103-
def test_simple_progress_without_total(self, mock_stderr) -> None:
104-
test_data = [1, 3, 5]
105-
desc = "test progress"
106-
107-
def gen():
108-
for n in test_data:
109-
yield n
110-
111-
progressed = progress(gen(), desc=desc, use_tqdm=False)
112-
113-
assert list(progressed) == test_data
114-
assert mock_stderr.getvalue().startswith(f"\r{desc}: ")
115-
assert mock_stderr.getvalue().endswith(f"\r{desc}: ...\n")
116-
117-
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
118-
def test_simple_progress_update_manually(self, mock_stderr) -> None:
119-
desc = "test progress"
120-
121-
p = progress(total=5, desc=desc, use_tqdm=False)
122-
p.update(0)
123-
p.update(2)
124-
p.update(2)
125-
p.update(1)
126-
p.close()
127-
assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/5")
128-
assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 5/5\n")

0 commit comments

Comments
 (0)