|
30 | 30 | _run_forward,
|
31 | 31 | )
|
32 | 32 | from captum._utils.exceptions import FeatureAblationFutureError
|
33 |
| -from captum._utils.progress import progress, SimpleProgress |
| 33 | +from captum._utils.progress import progress |
34 | 34 | from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
|
35 | 35 | from captum.attr._utils.attribution import PerturbationAttribution
|
36 | 36 | from captum.attr._utils.common import (
|
|
41 | 41 | from torch import dtype, Tensor
|
42 | 42 | from torch.futures import collect_all, Future
|
43 | 43 |
|
44 |
| -try: |
45 |
| - from tqdm.auto import tqdm |
46 |
| -except ImportError: |
47 |
| - tqdm = None |
| 44 | +from tqdm.auto import tqdm |
48 | 45 |
|
49 | 46 | IterableType = TypeVar("IterableType")
|
50 | 47 |
|
@@ -418,7 +415,7 @@ def _attribute_with_independent_feature_masks(
|
418 | 415 | formatted_feature_mask: Tuple[Tensor, ...],
|
419 | 416 | num_examples: int,
|
420 | 417 | perturbations_per_eval: int,
|
421 |
| - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], |
| 418 | + attr_progress: Optional[tqdm], |
422 | 419 | initial_eval: Tensor,
|
423 | 420 | flattened_initial_eval: Tensor,
|
424 | 421 | n_outputs: int,
|
@@ -500,7 +497,7 @@ def _attribute_with_cross_tensor_feature_masks(
|
500 | 497 | target: TargetType,
|
501 | 498 | baselines: BaselineType,
|
502 | 499 | formatted_feature_mask: Tuple[Tensor, ...],
|
503 |
| - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], |
| 500 | + attr_progress: Optional[tqdm], |
504 | 501 | flattened_initial_eval: Tensor,
|
505 | 502 | initial_eval: Tensor,
|
506 | 503 | n_outputs: int,
|
@@ -831,7 +828,7 @@ def _attribute_with_independent_feature_masks_future(
|
831 | 828 | baselines: BaselineType,
|
832 | 829 | formatted_feature_mask: Tuple[Tensor, ...],
|
833 | 830 | perturbations_per_eval: int,
|
834 |
| - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], |
| 831 | + attr_progress: Optional[tqdm], |
835 | 832 | processed_initial_eval_fut: Future[
|
836 | 833 | Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
|
837 | 834 | ],
|
@@ -942,7 +939,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
|
942 | 939 | target: TargetType,
|
943 | 940 | baselines: BaselineType,
|
944 | 941 | formatted_feature_mask: Tuple[Tensor, ...],
|
945 |
| - attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], |
| 942 | + attr_progress: Optional[tqdm], |
946 | 943 | processed_initial_eval_fut: Future[
|
947 | 944 | Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
|
948 | 945 | ],
|
@@ -1105,15 +1102,14 @@ def _fut_tuple_to_accumulate_fut_list_cross_tensor(
|
1105 | 1102 | "_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
|
1106 | 1103 | ) from e
|
1107 | 1104 |
|
1108 |
| - # pyre-fixme[3] return type must be annotated |
1109 | 1105 | def _attribute_progress_setup(
|
1110 | 1106 | self,
|
1111 | 1107 | formatted_inputs: Tuple[Tensor, ...],
|
1112 | 1108 | feature_mask: Tuple[Tensor, ...],
|
1113 | 1109 | enable_cross_tensor_attribution: bool,
|
1114 | 1110 | perturbations_per_eval: int,
|
1115 | 1111 | **kwargs: Any,
|
1116 |
| - ): |
| 1112 | + ) -> tqdm: |
1117 | 1113 | feature_counts = self._get_feature_counts(
|
1118 | 1114 | formatted_inputs, feature_mask, **kwargs
|
1119 | 1115 | )
|
|
0 commit comments