From 70ce97e6bffc0ec029ab4f2eeb18fc75160ab3ef Mon Sep 17 00:00:00 2001
From: Zach Carmichael
Date: Tue, 3 Jun 2025 20:57:05 -0700
Subject: [PATCH 1/2] Fix OSS mypy and pyre and conda-related testing and
 typing issues (#1573)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1573

Mypy released a new version, and so did Pyre, surfacing new type errors; this
diff fixes them and pins pyre-check-nightly to a matching version. A Rust
compiler also appears to be needed now, so I added it. The failures in
https://github.com/pytorch/captum/actions/runs/15359553880/job/43224685521
will hopefully just disappear: pip could not find a prebuilt wheel and fell
back to building one from source, which failed because no Rust compiler was
installed, even though a prebuilt wheel should definitely be available.

Differential Revision: D75896598

Reviewed By: cyrjano
---
 .pyre_configuration                 |  3 ++-
 captum/metrics/_core/infidelity.py  |  6 ++----
 captum/metrics/_core/sensitivity.py | 11 +++--------
 setup.py                            |  2 +-
 4 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/.pyre_configuration b/.pyre_configuration
index 28e3997ff..f48daf636 100644
--- a/.pyre_configuration
+++ b/.pyre_configuration
@@ -13,5 +13,6 @@
   "source_directories": [
     "."
   ],
-  "strict": true
+  "strict": true,
+  "version": "0.0.101745838703"
 }
diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py
index 203526752..316266e7c 100644
--- a/captum/metrics/_core/infidelity.py
+++ b/captum/metrics/_core/infidelity.py
@@ -499,8 +499,6 @@ def _generate_perturbations(
         repeated instances per example.
         """

-        # pyre-fixme[53]: Captured variable `baselines_expanded` is not annotated.
-        # pyre-fixme[53]: Captured variable `inputs_expanded` is not annotated.
        def call_perturb_func() -> (
            Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
        ):
@@ -520,12 +518,12 @@ def call_perturb_func() -> (
                else perturb_func(inputs_pert)
            )

-        inputs_expanded = tuple(
+        inputs_expanded: Tuple[Tensor, ...] = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )

-        baselines_expanded = baselines
+        baselines_expanded: BaselineTupleType = baselines
        if baselines is not None:
            baselines_expanded = tuple(
                (
diff --git a/captum/metrics/_core/sensitivity.py b/captum/metrics/_core/sensitivity.py
index bd925327e..a94d515fa 100644
--- a/captum/metrics/_core/sensitivity.py
+++ b/captum/metrics/_core/sensitivity.py
@@ -62,8 +62,7 @@ def default_perturb_func(

 @log_usage(part_of_slo=True)
 def sensitivity_max(
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    explanation_func: Callable,
+    explanation_func: Callable[..., TensorOrTupleOfTensorsGeneric],
     inputs: TensorOrTupleOfTensorsGeneric,
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     perturb_func: Callable = default_perturb_func,
@@ -232,8 +231,6 @@ def max_values(input_tnsr: Tensor) -> Tensor:
     # pyre-fixme[33]: Given annotation cannot be `Any`.
     kwargs_copy: Any = None

-    # pyre-fixme[53]: Captured variable `bsz` is not annotated.
-    # pyre-fixme[53]: Captured variable `expl_inputs` is not annotated.
     def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
         inputs_perturbed = _generate_perturbations(current_n_perturb_samples)

@@ -281,8 +278,6 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
                [
                    (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1)
                    for expl_perturbed, expl_input in zip(
-                        # pyre-fixme[6]: For 1st argument expected
-                        #  `Iterable[Variable[_T1]]` but got `None`.
                        expl_perturbed_inputs,
                        expl_inputs_expanded,
                    )
@@ -318,10 +313,10 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
     inputs = _format_tensor_into_tuples(inputs)  # type: ignore

-    bsz = inputs[0].size(0)
+    bsz: int = inputs[0].size(0)

     with torch.no_grad():
-        expl_inputs = explanation_func(inputs, **kwargs)
+        expl_inputs: TensorOrTupleOfTensorsGeneric = explanation_func(inputs, **kwargs)

     metrics_max = _divide_and_aggregate_metrics(
         cast(Tuple[Tensor, ...], inputs),
         n_perturb_samples,
diff --git a/setup.py b/setup.py
index 1d7614a84..bb1126589 100644
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,7 @@ def report(*args):
         "sphinx-autodoc-typehints",
         "sphinxcontrib-katex",
         "mypy>=0.760",
-        "pyre-check-nightly",
+        "pyre-check-nightly==0.0.101745838703",
         "usort==1.0.2",
         "ufmt",
         "scikit-learn",

From 0f51eaf08c70486a5897b6e3f48d34f09c8fcb47 Mon Sep 17 00:00:00 2001
From: Zach Carmichael
Date: Tue, 3 Jun 2025 21:02:15 -0700
Subject: [PATCH 2/2] Convert LLM Attribution result to a proper dataclass
 (#1572)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1572

Converts the LLM attribution result to a proper dataclass.

Reviewed By: aobo-y

Differential Revision: D75727427
---
 captum/attr/_core/llm_attr.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 47056d1fa..7c124b727 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -3,11 +3,9 @@

 import warnings
 from abc import ABC
-
 from copy import copy
-
+from dataclasses import dataclass
 from textwrap import shorten
-
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors
@@ -45,6 +43,7 @@
 }


+@dataclass
 class LLMAttributionResult:
     """
     Data class for the return result of LLMAttribution,
@@ -52,17 +51,10 @@ class LLMAttributionResult:
     It also provides utilities to help present and plot the result
     in different forms.
     """
-    def __init__(
-        self,
-        seq_attr: Tensor,
-        token_attr: Optional[Tensor],
-        input_tokens: List[str],
-        output_tokens: List[str],
-    ) -> None:
-        self.seq_attr = seq_attr
-        self.token_attr = token_attr
-        self.input_tokens = input_tokens
-        self.output_tokens = output_tokens
+    seq_attr: Tensor
+    token_attr: Optional[Tensor]
+    input_tokens: List[str]
+    output_tokens: List[str]

     @property
     def seq_attr_dict(self) -> Dict[str, float]:
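
Note on PATCH 1/2: the pyre-fixme[53] removals all follow one pattern. Rather
than suppressing "Captured variable is not annotated", the outer-scope
variables captured by inner closures (inputs_expanded, baselines_expanded,
bsz, expl_inputs) get explicit annotations. A minimal self-contained sketch of
that pattern, with hypothetical names (expand_and_sum, n_samples) standing in
for the real Captum internals:

    from typing import Tuple

    import torch
    from torch import Tensor


    def expand_and_sum(inputs: Tuple[Tensor, ...], n_samples: int) -> Tensor:
        # Annotating the captured variable gives the closure below a precise
        # type, so no pyre-fixme[53] suppression is needed.
        inputs_expanded: Tuple[Tensor, ...] = tuple(
            torch.repeat_interleave(t, n_samples, dim=0) for t in inputs
        )

        def summarize() -> Tensor:
            # `inputs_expanded` is captured from the enclosing scope; Pyre can
            # now type-check this body against the annotation above.
            return torch.stack([t.sum() for t in inputs_expanded])

        return summarize()


    print(expand_and_sum((torch.ones(2, 3), torch.zeros(2, 3)), n_samples=4))
    # tensor([24., 0.])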
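
The same patch also resolves pyre-fixme[24] on explanation_func by replacing a
bare Callable with Callable[..., TensorOrTupleOfTensorsGeneric], which leaves
the parameter list unconstrained while still typing the return value. A sketch
of that idiom (the alias name and run function here are illustrative, not part
of the Captum API):

    from typing import Callable, Tuple

    import torch
    from torch import Tensor

    # Bare `Callable` triggers pyre-fixme[24] ("Generic type `Callable` expects
    # 2 type parameters"); `Callable[..., ReturnType]` is the minimal fix when
    # argument types vary by caller but the return type is known.
    ExplanationFunc = Callable[..., Tuple[Tensor, ...]]


    def run(func: ExplanationFunc, inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        return func(inputs)


    print(run(lambda inputs: tuple(t * 2 for t in inputs), (torch.ones(2),)))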
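
Note on PATCH 2/2: @dataclass generates __init__, __repr__, and __eq__ from the
field declarations, and the generated __init__ takes the fields in declaration
order, so it matches the hand-written constructor it replaces and existing call
sites keep working. A minimal usage sketch (field names come from the diff; the
tensor and token values below are made up):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch
    from torch import Tensor


    @dataclass
    class LLMAttributionResult:
        # Same four fields the old __init__ assigned by hand.
        seq_attr: Tensor
        token_attr: Optional[Tensor]
        input_tokens: List[str]
        output_tokens: List[str]


    result = LLMAttributionResult(
        seq_attr=torch.tensor([0.2, 0.5, 0.3]),
        token_attr=None,
        input_tokens=["The", "quick", "fox"],
        output_tokens=["jumps"],
    )
    print(result)  # auto-generated __repr__ lists every field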