Skip to content

Commit 752fc1d

Browse files
ArjunR2404facebook-github-bot
authored andcommitted
Starter Task 2: Adding return type annotation to the _attribute_progress_setup method (#1576)
Summary: The problem was that the _attribute_progress_setup method in feature_ablation.py didn't have a return type annotation. After making the relevant changes in the task 1 diff to only have the progress function return a tqdm object, I added a tqdm return type annotation to the _attribute_progress_setup function. Differential Revision: D75972827
1 parent f8762d6 commit 752fc1d

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
_run_forward,
3131
)
3232
from captum._utils.exceptions import FeatureAblationFutureError
33-
from captum._utils.progress import progress, SimpleProgress
33+
from captum._utils.progress import progress
3434
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
3535
from captum.attr._utils.attribution import PerturbationAttribution
3636
from captum.attr._utils.common import (
@@ -41,10 +41,7 @@
4141
from torch import dtype, Tensor
4242
from torch.futures import collect_all, Future
4343

44-
try:
45-
from tqdm.auto import tqdm
46-
except ImportError:
47-
tqdm = None
44+
from tqdm.auto import tqdm
4845

4946
IterableType = TypeVar("IterableType")
5047

@@ -418,7 +415,7 @@ def _attribute_with_independent_feature_masks(
418415
formatted_feature_mask: Tuple[Tensor, ...],
419416
num_examples: int,
420417
perturbations_per_eval: int,
421-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
418+
attr_progress: Optional[tqdm],
422419
initial_eval: Tensor,
423420
flattened_initial_eval: Tensor,
424421
n_outputs: int,
@@ -500,7 +497,7 @@ def _attribute_with_cross_tensor_feature_masks(
500497
target: TargetType,
501498
baselines: BaselineType,
502499
formatted_feature_mask: Tuple[Tensor, ...],
503-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
500+
attr_progress: Optional[tqdm],
504501
flattened_initial_eval: Tensor,
505502
initial_eval: Tensor,
506503
n_outputs: int,
@@ -831,7 +828,7 @@ def _attribute_with_independent_feature_masks_future(
831828
baselines: BaselineType,
832829
formatted_feature_mask: Tuple[Tensor, ...],
833830
perturbations_per_eval: int,
834-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
831+
attr_progress: Optional[tqdm],
835832
processed_initial_eval_fut: Future[
836833
Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
837834
],
@@ -942,7 +939,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
942939
target: TargetType,
943940
baselines: BaselineType,
944941
formatted_feature_mask: Tuple[Tensor, ...],
945-
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
942+
attr_progress: Optional[tqdm],
946943
processed_initial_eval_fut: Future[
947944
Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
948945
],
@@ -1105,15 +1102,14 @@ def _fut_tuple_to_accumulate_fut_list_cross_tensor(
11051102
"_fut_tuple_to_accumulate_fut_list_cross_tensor failed"
11061103
) from e
11071104

1108-
# pyre-fixme[3] return type must be annotated
11091105
def _attribute_progress_setup(
11101106
self,
11111107
formatted_inputs: Tuple[Tensor, ...],
11121108
feature_mask: Tuple[Tensor, ...],
11131109
enable_cross_tensor_attribution: bool,
11141110
perturbations_per_eval: int,
11151111
**kwargs: Any,
1116-
):
1112+
) -> tqdm:
11171113
feature_counts = self._get_feature_counts(
11181114
formatted_inputs, feature_mask, **kwargs
11191115
)

0 commit comments

Comments
 (0)