diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py index a789472f1..df09e6616 100644 --- a/captum/_utils/progress.py +++ b/captum/_utils/progress.py @@ -2,10 +2,7 @@ # pyre-strict -import sys import typing -import warnings -from time import time from types import TracebackType from typing import ( Any, @@ -15,17 +12,13 @@ Iterator, Literal, Optional, - Sized, TextIO, Type, TypeVar, Union, ) -try: - from tqdm.auto import tqdm -except ImportError: - tqdm = None +from tqdm.auto import tqdm T = TypeVar("T") IterableType = TypeVar("IterableType") @@ -105,103 +98,15 @@ def close(self) -> None: pass -class SimpleProgress(Iterable[IterableType]): - def __init__( - self, - iterable: Optional[Iterable[IterableType]] = None, - desc: Optional[str] = None, - total: Optional[int] = None, - file: Optional[TextIO] = None, - mininterval: float = 0.5, - ) -> None: - """ - Simple progress output used when tqdm is unavailable. - Same as tqdm, output to stderr channel. - If you want to do nested Progressbars with simple progress - the parent progress bar should be used as a context - (i.e. with statement) and the nested progress bar should be - created inside this context. - """ - self.cur = 0 - self.iterable = iterable - self.total = total - if total is None and hasattr(iterable, "__len__"): - self.total = len(cast(Sized, iterable)) - - self.desc = desc - - file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr) - self.file: DisableErrorIOWrapper = file_wrapper - - self.mininterval = mininterval - self.last_print_t = 0.0 - self.closed = False - self._is_parent = False - - def __enter__(self) -> "SimpleProgress[IterableType]": - self._is_parent = True - self._refresh() - return self - - def __exit__( - self, - exc_type: Union[Type[BaseException], None], - exc_value: Union[BaseException, None], - exc_traceback: Union[TracebackType, None], - ) -> Literal[False]: - self.close() - return False - - def __iter__(self) -> Iterator[IterableType]: - if self.closed or not self.iterable: - return - self._refresh() - for it in cast(Iterable[IterableType], self.iterable): - yield it - self.update() - self.close() - - def _refresh(self) -> None: - progress_str = self.desc + ": " if self.desc else "" - if self.total: - # e.g., progress: 60% 3/5 - progress_str += ( - f"{100 * self.cur // cast(int, self.total)}%" - f" {self.cur}/{cast(int, self.total)}" - ) - else: - # e.g., progress: ..... - progress_str += "." * self.cur - end = "\n" if self._is_parent else "" - print("\r" + progress_str, end=end, file=self.file) - - def update(self, amount: int = 1) -> None: - if self.closed: - return - self.cur += amount - - cur_t = time() - if cur_t - self.last_print_t >= self.mininterval: - self._refresh() - self.last_print_t = cur_t - - def close(self) -> None: - if not self.closed and not self._is_parent: - self._refresh() - print(file=self.file) # end with new line - self.closed = True - - @typing.overload def progress( iterable: None = None, desc: Optional[str] = None, total: Optional[int] = None, - use_tqdm: bool = True, file: Optional[TextIO] = None, mininterval: float = 0.5, **kwargs: object, -) -> Union[SimpleProgress[None], tqdm]: ... +) -> tqdm: ... @typing.overload @@ -209,40 +114,25 @@ def progress( iterable: Iterable[IterableType], desc: Optional[str] = None, total: Optional[int] = None, - use_tqdm: bool = True, file: Optional[TextIO] = None, mininterval: float = 0.5, **kwargs: object, -) -> Union[SimpleProgress[IterableType], tqdm]: ... +) -> tqdm: ... def progress( iterable: Optional[Iterable[IterableType]] = None, desc: Optional[str] = None, total: Optional[int] = None, - use_tqdm: bool = True, file: Optional[TextIO] = None, mininterval: float = 0.5, **kwargs: object, -) -> Union[SimpleProgress[IterableType], tqdm]: - # Try to use tqdm is possible. Fall back to simple progress print - if tqdm and use_tqdm: - return tqdm( - iterable, - desc=desc, - total=total, - file=file, - mininterval=mininterval, - **kwargs, - ) - else: - if not tqdm and use_tqdm: - warnings.warn( - "Tried to show progress with tqdm " - "but tqdm is not installed. " - "Fall back to simply print out the progress.", - stacklevel=1, - ) - return SimpleProgress( - iterable, desc=desc, total=total, file=file, mininterval=mininterval - ) +) -> tqdm: + return tqdm( + iterable, + desc=desc, + total=total, + file=file, + mininterval=mininterval, + **kwargs, + ) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 95d4cfbb3..ab9b9f9c6 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -30,7 +30,7 @@ _run_forward, ) from captum._utils.exceptions import FeatureAblationFutureError -from captum._utils.progress import progress, SimpleProgress +from captum._utils.progress import progress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution from captum.attr._utils.common import ( @@ -41,10 +41,7 @@ from torch import dtype, Tensor from torch.futures import collect_all, Future -try: - from tqdm.auto import tqdm -except ImportError: - tqdm = None +from tqdm.auto import tqdm IterableType = TypeVar("IterableType") @@ -418,7 +415,7 @@ def _attribute_with_independent_feature_masks( formatted_feature_mask: Tuple[Tensor, ...], num_examples: int, perturbations_per_eval: int, - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + attr_progress: Optional[tqdm], initial_eval: Tensor, flattened_initial_eval: Tensor, n_outputs: int, @@ -500,7 +497,7 @@ def _attribute_with_cross_tensor_feature_masks( target: TargetType, baselines: BaselineType, formatted_feature_mask: Tuple[Tensor, ...], - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + attr_progress: Optional[tqdm], flattened_initial_eval: Tensor, initial_eval: Tensor, n_outputs: int, @@ -831,7 +828,7 @@ def _attribute_with_independent_feature_masks_future( baselines: BaselineType, formatted_feature_mask: Tuple[Tensor, ...], perturbations_per_eval: int, - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + attr_progress: Optional[tqdm], processed_initial_eval_fut: Future[ Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype] ], @@ -942,7 +939,7 @@ def _attribute_with_cross_tensor_feature_masks_future( target: TargetType, baselines: BaselineType, formatted_feature_mask: Tuple[Tensor, ...], - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], + attr_progress: Optional[tqdm], processed_initial_eval_fut: Future[ Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype] ], @@ -1105,7 +1102,6 @@ def _fut_tuple_to_accumulate_fut_list_cross_tensor( "_fut_tuple_to_accumulate_fut_list_cross_tensor failed" ) from e - # pyre-fixme[3] return type must be annotated def _attribute_progress_setup( self, formatted_inputs: Tuple[Tensor, ...], @@ -1113,7 +1109,7 @@ def _attribute_progress_setup( enable_cross_tensor_attribution: bool, perturbations_per_eval: int, **kwargs: Any, - ): + ) -> tqdm: feature_counts = self._get_feature_counts( formatted_inputs, feature_mask, **kwargs ) diff --git a/tests/utils/test_progress.py b/tests/utils/test_progress.py index a87b997f0..759f10683 100644 --- a/tests/utils/test_progress.py +++ b/tests/utils/test_progress.py @@ -41,26 +41,6 @@ def test_nested_progress_tqdm(self, mock_stderr) -> None: for item in parent_data: self.assertIn(f"test progress {item}:", output) - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) - def test_nested_simple_progress(self, mock_stderr) -> None: - parent_data = ["x", "y", "z"] - test_data = [1, 2, 3] - with progress( - parent_data, desc="parent progress", use_tqdm=False, mininterval=0.0 - ) as parent: - for item in parent: - for _ in progress( - test_data, desc=f"test progress {item}", use_tqdm=False - ): - pass - - output = mock_stderr.getvalue() - self.assertEqual( - output.count("parent progress:"), 5, "5 'parent' progress bar expected" - ) - for item in parent_data: - self.assertIn(f"test progress {item}:", output) - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) def test_progress_tqdm(self, mock_stderr) -> None: try: @@ -73,56 +53,3 @@ def test_progress_tqdm(self, mock_stderr) -> None: progressed = progress(test_data, desc="test progress") assert list(progressed) == test_data assert "test progress: " in mock_stderr.getvalue() - - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) - def test_simple_progress(self, mock_stderr) -> None: - test_data = [1, 3, 5] - desc = "test progress" - - progressed = progress(test_data, desc=desc, use_tqdm=False) - - assert list(progressed) == test_data - assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3") - assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n") - - # progress iterable without len but explicitly specify total - def gen(): - for n in test_data: - yield n - - mock_stderr.seek(0) - mock_stderr.truncate(0) - - progressed = progress(gen(), desc=desc, total=len(test_data), use_tqdm=False) - - assert list(progressed) == test_data - assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/3") - assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 3/3\n") - - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) - def test_simple_progress_without_total(self, mock_stderr) -> None: - test_data = [1, 3, 5] - desc = "test progress" - - def gen(): - for n in test_data: - yield n - - progressed = progress(gen(), desc=desc, use_tqdm=False) - - assert list(progressed) == test_data - assert mock_stderr.getvalue().startswith(f"\r{desc}: ") - assert mock_stderr.getvalue().endswith(f"\r{desc}: ...\n") - - @unittest.mock.patch("sys.stderr", new_callable=io.StringIO) - def test_simple_progress_update_manually(self, mock_stderr) -> None: - desc = "test progress" - - p = progress(total=5, desc=desc, use_tqdm=False) - p.update(0) - p.update(2) - p.update(2) - p.update(1) - p.close() - assert mock_stderr.getvalue().startswith(f"\r{desc}: 0% 0/5") - assert mock_stderr.getvalue().endswith(f"\r{desc}: 100% 5/5\n")