Convert LLM Attribution result to a proper dataclass #1572

Closed · wants to merge 2 commits
3 changes: 2 additions & 1 deletion .pyre_configuration
@@ -13,5 +13,6 @@
   "source_directories": [
     "."
   ],
-  "strict": true
+  "strict": true,
+  "version": "0.0.101745838703"
 }
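
The pinned version here has to stay in sync with the pyre-check-nightly pin added to setup.py below. As a minimal sketch, a consistency check one could run locally from the repo root (this helper is illustrative only and not part of the PR):

# Illustrative helper (not part of this PR): confirm the pyre version pinned in
# .pyre_configuration matches the pyre-check-nightly pin in setup.py.
import json
import re
from pathlib import Path

config_version = json.loads(Path(".pyre_configuration").read_text())["version"]
setup_match = re.search(r"pyre-check-nightly==([\d.]+)", Path("setup.py").read_text())
assert setup_match and setup_match.group(1) == config_version, "pyre version pins disagree"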
20 changes: 6 additions & 14 deletions captum/attr/_core/llm_attr.py
@@ -3,11 +3,9 @@
 import warnings
 
 from abc import ABC
-
 from copy import copy
-
+from dataclasses import dataclass
 from textwrap import shorten
-
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
@@ -45,24 +43,18 @@
 }
 
 
+@dataclass
 class LLMAttributionResult:
     """
     Data class for the return result of LLMAttribution,
     which includes the necessary properties of the attribution.
     It also provides utilities to help present and plot the result in different forms.
     """
 
-    def __init__(
-        self,
-        seq_attr: Tensor,
-        token_attr: Optional[Tensor],
-        input_tokens: List[str],
-        output_tokens: List[str],
-    ) -> None:
-        self.seq_attr = seq_attr
-        self.token_attr = token_attr
-        self.input_tokens = input_tokens
-        self.output_tokens = output_tokens
+    seq_attr: Tensor
+    token_attr: Optional[Tensor]
+    input_tokens: List[str]
+    output_tokens: List[str]
 
     @property
     def seq_attr_dict(self) -> Dict[str, float]:
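
With the class declared as a @dataclass, the generated __init__ takes the same four arguments as the hand-written one, and __repr__ and __eq__ come for free. A minimal usage sketch (the tensors and tokens below are made up for illustration):

# Illustrative only: construction is unchanged by the dataclass conversion.
import torch
from captum.attr._core.llm_attr import LLMAttributionResult

result = LLMAttributionResult(
    seq_attr=torch.tensor([0.5, -0.2, 0.1]),
    token_attr=None,
    input_tokens=["The", "cat", "sat"],
    output_tokens=["on"],
)
# seq_attr_dict pairs each input token with its sequence-level attribution score.
print(result.seq_attr_dict)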
6 changes: 2 additions & 4 deletions captum/metrics/_core/infidelity.py
@@ -499,8 +499,6 @@ def _generate_perturbations(
             repeated instances per example.
         """
 
-        # pyre-fixme[53]: Captured variable `baselines_expanded` is not annotated.
-        # pyre-fixme[53]: Captured variable `inputs_expanded` is not annotated.
         def call_perturb_func() -> (
             Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
         ):
@@ -520,12 +518,12 @@ def call_perturb_func() -> (
                 else perturb_func(inputs_pert)
             )
 
-        inputs_expanded = tuple(
+        inputs_expanded: Tuple[Tensor, ...] = tuple(
             torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
             for input in inputs
         )
 
-        baselines_expanded = baselines
+        baselines_expanded: BaselineTupleType = baselines
         if baselines is not None:
             baselines_expanded = tuple(
                 (
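
The two pyre-fixme[53] suppressions become unnecessary because the variables captured by the nested call_perturb_func are now annotated at their assignment sites. A small self-contained sketch of the same pattern (the names here are illustrative, not taken from captum):

# Illustrative pattern: annotate a variable where it is assigned so that a
# closure capturing it is fully typed and needs no pyre-fixme[53] suppression.
from typing import Tuple

import torch
from torch import Tensor


def expand_and_sum(inputs: Tuple[Tensor, ...], n_samples: int) -> Tensor:
    expanded: Tuple[Tensor, ...] = tuple(
        torch.repeat_interleave(t, n_samples, dim=0) for t in inputs
    )

    def use_captured() -> Tensor:
        # `expanded` is captured from the enclosing scope; its annotation above
        # is what keeps the type checker satisfied.
        return torch.cat(expanded, dim=0)

    return use_captured()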
11 changes: 3 additions & 8 deletions captum/metrics/_core/sensitivity.py
@@ -62,8 +62,7 @@ def default_perturb_func(
 
 @log_usage(part_of_slo=True)
 def sensitivity_max(
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    explanation_func: Callable,
+    explanation_func: Callable[..., TensorOrTupleOfTensorsGeneric],
     inputs: TensorOrTupleOfTensorsGeneric,
     # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
     perturb_func: Callable = default_perturb_func,
@@ -232,8 +231,6 @@ def max_values(input_tnsr: Tensor) -> Tensor:
     # pyre-fixme[33]: Given annotation cannot be `Any`.
     kwargs_copy: Any = None
 
-    # pyre-fixme[53]: Captured variable `bsz` is not annotated.
-    # pyre-fixme[53]: Captured variable `expl_inputs` is not annotated.
     def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
         inputs_perturbed = _generate_perturbations(current_n_perturb_samples)
 
@@ -281,8 +278,6 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
             [
                 (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1)
                 for expl_perturbed, expl_input in zip(
-                    # pyre-fixme[6]: For 1st argument expected
-                    # `Iterable[Variable[_T1]]` but got `None`.
                     expl_perturbed_inputs,
                     expl_inputs_expanded,
                 )
@@ -318,10 +313,10 @@ def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
 
     inputs = _format_tensor_into_tuples(inputs)  # type: ignore
 
-    bsz = inputs[0].size(0)
+    bsz: int = inputs[0].size(0)
 
     with torch.no_grad():
-        expl_inputs = explanation_func(inputs, **kwargs)
+        expl_inputs: TensorOrTupleOfTensorsGeneric = explanation_func(inputs, **kwargs)
         metrics_max = _divide_and_aggregate_metrics(
             cast(Tuple[Tensor, ...], inputs),
             n_perturb_samples,
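
The tightened annotation on explanation_func documents that it must accept the (optionally perturbed) inputs plus any forwarded keyword arguments and return attributions in the same tensor-or-tuple form. A hedged usage sketch with a gradient-based attribution method (the toy model and data are placeholders):

# Illustrative only: saliency.attribute satisfies the new
# Callable[..., TensorOrTupleOfTensorsGeneric] annotation.
import torch
from captum.attr import Saliency
from captum.metrics import sensitivity_max

model = torch.nn.Linear(3, 2)  # placeholder model
saliency = Saliency(model)
inputs = torch.randn(4, 3, requires_grad=True)

# Extra keyword arguments (here `target`) are forwarded to explanation_func.
sens = sensitivity_max(saliency.attribute, inputs, target=0)
print(sens.shape)  # one sensitivity score per example in the batch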
2 changes: 1 addition & 1 deletion setup.py
@@ -76,7 +76,7 @@ def report(*args):
     "sphinx-autodoc-typehints",
     "sphinxcontrib-katex",
     "mypy>=0.760",
-    "pyre-check-nightly",
+    "pyre-check-nightly==0.0.101745838703",
     "usort==1.0.2",
     "ufmt",
     "scikit-learn",