import logging
from copy import deepcopy
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
import torch
from ..utils import (
abs_max,
drop_padding,
get_sequences_from_batched_steps,
identity_fn,
json_advanced_dump,
json_advanced_load,
normalize_attributions,
pretty_dict,
prod_fn,
remap_from_filtered,
sum_fn,
sum_normalize_attributions,
)
from ..utils.typing import (
MultipleScoresPerSequenceTensor,
MultipleScoresPerStepTensor,
OneOrMoreTokenWithIdSequences,
SequenceAttributionTensor,
SingleScorePerStepTensor,
SingleScoresPerSequenceTensor,
StepAttributionTensor,
TargetIdsTensor,
TextInput,
TokenWithId,
)
from .aggregator import AggregableMixin, Aggregator, AggregatorPipeline, SequenceAttributionAggregator
from .batch import Batch, BatchEncoding
from .data_utils import TensorWrapper
FeatureAttributionInput = Union[TextInput, BatchEncoding, Batch]
DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
"source_attributions": {"sequence_aggregate": identity_fn, "span_aggregate": abs_max},
"target_attributions": {"sequence_aggregate": identity_fn, "span_aggregate": abs_max},
"step_scores": {
"span_aggregate": {
"probability": prod_fn,
"entropy": sum_fn,
"crossentropy": sum_fn,
"perplexity": prod_fn,
"contrast_prob_diff": prod_fn,
"mc_dropout_prob_avg": prod_fn,
}
},
}
logger = logging.getLogger(__name__)
[docs]@dataclass(eq=False, repr=False)
class FeatureAttributionSequenceOutput(TensorWrapper, AggregableMixin):
"""
Output produced by a standard attribution method.
Attributes:
source (list of :class:`~inseq.utils.typing.TokenWithId`): Tokenized source sequence.
target (list of :class:`~inseq.utils.typing.TokenWithId`): Tokenized target sequence.
source_attributions (:obj:`SequenceAttributionTensor`): Tensor of shape (`source_len`,
`target_len`) plus an optional third dimension if the attribution is granular (e.g.
gradient attribution) containing the attribution scores produced at each generation step of
the target for every source token.
target_attributions (:obj:`SequenceAttributionTensor`, optional): Tensor of shape
(`target_len`, `target_len`), plus an optional third dimension if
the attribution is granular containing the attribution scores produced at each generation
step of the target for every token in the target prefix.
step_scores (:obj:`dict[str, SingleScorePerStepTensor]`, optional): Dictionary of step scores
produced alongside attributions (one per generation step).
sequence_scores (:obj:`dict[str, MultipleScoresPerStepTensor]`, optional): Dictionary of sequence
scores produced alongside attributions (n per generation step, as for attributions).
"""
source: List[TokenWithId]
target: List[TokenWithId]
source_attributions: Optional[SequenceAttributionTensor] = None
target_attributions: Optional[SequenceAttributionTensor] = None
step_scores: Optional[Dict[str, SingleScoresPerSequenceTensor]] = None
sequence_scores: Optional[Dict[str, MultipleScoresPerSequenceTensor]] = None
attr_pos_start: int = 0
attr_pos_end: Optional[int] = None
_aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None
_dict_aggregate_fn: Dict[str, Any] = None
def __post_init__(self):
aggregate_dict = DEFAULT_ATTRIBUTION_AGGREGATE_DICT
if self._dict_aggregate_fn is None or self._dict_aggregate_fn == {}:
self._dict_aggregate_fn = aggregate_dict
elif isinstance(self._dict_aggregate_fn, dict):
aggregate_dict.update(self._dict_aggregate_fn)
self._dict_aggregate_fn = aggregate_dict
if self._aggregator is None:
self._aggregator = SequenceAttributionAggregator
if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
self.attr_pos_end = len(self.target)
[docs] @classmethod
def from_step_attributions(
cls,
attributions: List["FeatureAttributionStepOutput"],
tokenized_target_sentences: Optional[List[List[TokenWithId]]] = None,
pad_id: Optional[Any] = None,
has_bos_token: bool = True,
attr_pos_end: Optional[int] = None,
) -> List["FeatureAttributionSequenceOutput"]:
"""Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple
examples outputs per step into a list of :class:`~inseq.data.attribution.FeatureAttributionSequenceOutput` with
every object containing all step outputs for an individual example.
Raises:
`ValueError`: If the number of sequences in the attributions is not the same for all input sequences.
Returns:
`List[FeatureAttributionSequenceOutput]`: List of
:class:`~inseq.data.attribution.FeatureAttributionSequenceOutput` objects.
"""
attr = attributions[0]
seq_attr_cls = attr._sequence_cls
num_sequences = len(attr.prefix)
if not all([len(attr.prefix) == num_sequences for attr in attributions]):
raise ValueError("All the attributions must include the same number of sequences.")
seq_attributions = []
sources = None
if attr.source_attributions is not None:
sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)]
targets = [
drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences)
]
if tokenized_target_sentences is None:
tokenized_target_sentences = targets
if attr_pos_end is None:
attr_pos_end = max([len(t) for t in tokenized_target_sentences])
pos_start = [
min(len(tokenized_target_sentences[seq_id]), attr_pos_end) - len(targets[seq_id])
for seq_id in range(num_sequences)
]
for seq_id in range(num_sequences):
source = tokenized_target_sentences[seq_id][: pos_start[seq_id]] if sources is None else sources[seq_id]
seq_attributions.append(
seq_attr_cls(
source=source,
target=tokenized_target_sentences[seq_id],
attr_pos_start=pos_start[seq_id],
attr_pos_end=attr_pos_end,
)
)
if attr.source_attributions is not None:
source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions])
for seq_id in range(num_sequences):
# Remove padding from tensor
filtered_source_attribution = source_attributions[seq_id][
: len(sources[seq_id]), : len(targets[seq_id]), ...
]
seq_attributions[seq_id].source_attributions = filtered_source_attribution
if attr.target_attributions is not None:
target_attributions = get_sequences_from_batched_steps(
[att.target_attributions for att in attributions], pad_dims=(1,)
)
for seq_id in range(num_sequences):
if has_bos_token:
target_attributions[seq_id] = target_attributions[seq_id][1:, ...]
start_idx = max(pos_start) - pos_start[seq_id]
end_idx = start_idx + len(tokenized_target_sentences[seq_id])
target_attributions[seq_id] = target_attributions[seq_id][
start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203
]
if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]):
empty_final_row = torch.ones(1, *target_attributions[seq_id].shape[1:]) * float("nan")
target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0)
seq_attributions[seq_id].target_attributions = target_attributions[seq_id]
if attr.step_scores is not None:
step_scores = [{} for _ in range(num_sequences)]
for step_score_name in attr.step_scores.keys():
out_step_scores = get_sequences_from_batched_steps(
[att.step_scores[step_score_name] for att in attributions]
)
for seq_id in range(num_sequences):
step_scores[seq_id][step_score_name] = out_step_scores[seq_id][: len(targets[seq_id])]
for seq_id in range(num_sequences):
seq_attributions[seq_id].step_scores = step_scores[seq_id]
if attr.sequence_scores is not None:
seq_scores = [{} for _ in range(num_sequences)]
for seq_score_name in attr.sequence_scores.keys():
out_seq_scores = get_sequences_from_batched_steps(
[att.sequence_scores[seq_score_name] for att in attributions]
)
for seq_id in range(num_sequences):
seq_scores[seq_id][seq_score_name] = out_seq_scores[seq_id][
: len(sources[seq_id]), : len(targets[seq_id]), ...
]
for seq_id in range(num_sequences):
seq_attributions[seq_id].sequence_scores = seq_scores[seq_id]
return seq_attributions
[docs] def show(
self,
min_val: Optional[int] = None,
max_val: Optional[int] = None,
display: bool = True,
return_html: Optional[bool] = False,
aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None,
do_aggregation: bool = True,
**kwargs,
) -> Optional[str]:
"""Visualize the attributions.
Args:
min_val (:obj:`int`, *optional*, defaults to None):
Minimum value in the color range of the visualization. If None, the minimum value of the attributions
across all visualized examples is used.
max_val (:obj:`int`, *optional*, defaults to None):
Maximum value in the color range of the visualization. If None, the maximum value of the attributions
across all visualized examples is used.
display (:obj:`bool`, *optional*, defaults to True):
Whether to display the visualization. Can be set to False if the visualization is produced and stored
for later use.
return_html (:obj:`bool`, *optional*, defaults to False):
Whether to return the HTML code of the visualization.
aggregator (:obj:`AggregatorPipeline`, *optional*, defaults to None):
Aggregates attributions before visualizing them. If not specified, the default aggregator for the class
is used.
do_aggregation (:obj:`bool`, *optional*, defaults to True):
Whether to aggregate the attributions before visualizing them. Allows to skip aggregation if the
attributions are already aggregated.
Returns:
:obj:`str`: The HTML code of the visualization if :obj:`return_html` is set to True, otherwise None.
"""
from inseq import show_attributions
# If no aggregator is specified, the default aggregator for the class is used
aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self
if (aggregated.source_attributions is not None and aggregated.source_attributions.shape[1] == 0) or (
aggregated.target_attributions is not None and aggregated.target_attributions.shape[1] == 0
):
tokens = "".join(tid.token for tid in self.target)
logger.warning(f"Found empty attributions, skipping attribution matching generation: {tokens}")
else:
return show_attributions(aggregated, min_val, max_val, display, return_html)
@property
def minimum(self) -> float:
minimum = 0
if self.source_attributions is not None:
minimum = min(minimum, float(torch.nan_to_num(self.source_attributions).min()))
if self.target_attributions is not None:
minimum = min(minimum, float(torch.nan_to_num(self.target_attributions).min()))
return minimum
@property
def maximum(self) -> float:
maximum = 0
if self.source_attributions is not None:
maximum = max(maximum, float(torch.nan_to_num(self.source_attributions).max()))
if self.target_attributions is not None:
maximum = max(maximum, float(torch.nan_to_num(self.target_attributions).max()))
return maximum
[docs] def weight_attributions(self, step_fn_id: str):
"""Weights attribution scores in place by the value of the selected step function for every generation step.
Args:
step_fn_id (`str`):
The id of the step function to use for weighting the attributions (e.g. ``probability``)
"""
aggregated_attr = self.aggregate()
step_scores = self.step_scores[step_fn_id].T.unsqueeze(1)
if self.source_attributions is not None:
source_attr = aggregated_attr.source_attributions.float().T
self.source_attributions = (step_scores * source_attr).T
if self.target_attributions is not None:
target_attr = aggregated_attr.target_attributions.float().T
self.target_attributions = (step_scores * target_attr).T
self._aggregator = AggregatorPipeline([])
return self
def get_scores_dicts(
self,
aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None,
do_aggregation: bool = True,
**kwargs,
) -> Dict[str, Dict[str, Dict[str, float]]]:
# If no aggregator is specified, the default aggregator for the class is used
aggr = self.aggregate(aggregator, **kwargs) if do_aggregation else self
return_dict = {"source_attributions": {}, "target_attributions": {}, "step_scores": {}}
if aggr.source_attributions is not None:
score_map_source = {}
for tgt_idx, tgt_tok in enumerate(aggr.target):
score_map_source[tgt_tok.token] = {}
for src_idx, src_tok in enumerate(aggr.source):
score_map_source[tgt_tok.token][src_tok.token] = aggr.source_attributions[src_idx, tgt_idx].item()
return_dict["source_attributions"] = score_map_source
if aggr.target_attributions is not None:
score_map_target = {}
for tgt_idx_b, tgt_tok_b in enumerate(aggr.target):
score_map_target[tgt_tok_b.token] = {}
for tgt_idx_a, tgt_tok_a in enumerate(aggr.target):
score_map_target[tgt_tok_b.token][tgt_tok_a.token] = aggr.target_attributions[
tgt_idx_a, tgt_idx_b
].item()
return_dict["target_attributions"] = score_map_target
if aggr.step_scores is not None:
step_scores_map = {}
for tgt_idx, tgt_tok in enumerate(aggr.target):
step_scores_map[tgt_tok.token] = {}
for step_score_id, step_score in aggr.step_scores.items():
step_scores_map[tgt_tok.token][step_score_id] = step_score[tgt_idx].item()
return_dict["step_scores"] = step_scores_map
return return_dict
[docs]@dataclass(eq=False, repr=False)
class FeatureAttributionStepOutput(TensorWrapper):
"""Output of a single step of feature attribution, plus extra information related to what was attributed."""
source_attributions: Optional[StepAttributionTensor] = None
step_scores: Optional[Dict[str, SingleScorePerStepTensor]] = None
target_attributions: Optional[StepAttributionTensor] = None
sequence_scores: Optional[Dict[str, MultipleScoresPerStepTensor]] = None
source: Optional[OneOrMoreTokenWithIdSequences] = None
prefix: Optional[OneOrMoreTokenWithIdSequences] = None
target: Optional[OneOrMoreTokenWithIdSequences] = None
_sequence_cls: Type["FeatureAttributionSequenceOutput"] = FeatureAttributionSequenceOutput
def __post_init__(self):
self.to(torch.float32)
[docs] def remap_from_filtered(
self,
target_attention_mask: TargetIdsTensor,
) -> None:
"""Remaps the attributions to the original shape of the input sequence."""
if self.source_attributions is not None:
self.source_attributions = remap_from_filtered(
original_shape=(len(self.source), *self.source_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.source_attributions,
)
if self.target_attributions is not None:
self.target_attributions = remap_from_filtered(
original_shape=(len(self.prefix), *self.target_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.target_attributions,
)
if self.step_scores is not None:
for score_name, score_tensor in self.step_scores.items():
self.step_scores[score_name] = remap_from_filtered(
original_shape=(len(self.prefix), 1),
mask=target_attention_mask,
filtered=score_tensor.unsqueeze(-1),
).squeeze(-1)
if self.sequence_scores is not None:
for score_name, score_tensor in self.sequence_scores.items():
self.sequence_scores[score_name] = remap_from_filtered(
original_shape=(len(self.source), *self.source_attributions.shape[1:]),
mask=target_attention_mask,
filtered=score_tensor,
)
[docs]@dataclass
class FeatureAttributionOutput:
"""
Output produced by the `AttributionModel.attribute` method.
Attributes:
sequence_attributions (list of :class:`~inseq.data.FeatureAttributionSequenceOutput`): List
containing all attributions performed on input sentences (one per input sentence, including
source and optionally target-side attribution).
step_attributions (list of :class:`~inseq.data.FeatureAttributionStepOutput`, optional): List
containing all step attributions (one per generation step performed on the batch), returned if
`output_step_attributions=True`.
info (dict with str keys and any values): Dictionary including all available parameters used to
perform the attribution.
"""
# These fields of the info dictionary should be matching to allow merging
_merge_match_info_fields = [
"attribute_target",
"attribution_method",
"constrained_decoding",
"include_eos_baseline",
"model_class",
"model_name",
"step_scores",
"tokenizer_class",
"tokenizer_name",
]
sequence_attributions: List[FeatureAttributionSequenceOutput]
step_attributions: Optional[List[FeatureAttributionStepOutput]] = None
info: Dict[str, Any] = field(default_factory=dict)
def __str__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"
def __repr__(self):
return self.__str__()
def __eq__(self, other):
for self_seq, other_seq in zip(self.sequence_attributions, other.sequence_attributions):
if self_seq != other_seq:
return False
if self.step_attributions is not None and other.step_attributions is not None:
for self_step, other_step in zip(self.step_attributions, other.step_attributions):
if self_step != other_step:
return False
if self.info != other.info:
return False
return True
[docs] def save(
self,
path: PathLike,
overwrite: bool = False,
compress: bool = False,
ndarray_compact: bool = True,
use_primitives: bool = False,
split_sequences: bool = False,
) -> None:
"""
Save class contents to a JSON file.
Args:
path (:obj:`os.PathLike`): Path to the folder where the attribution output will be stored
(e.g. ``./out.json``).
overwrite (:obj:`bool`, *optional*, defaults to False):
If True, overwrite the file if it exists, raise error otherwise.
compress (:obj:`bool`, *optional*, defaults to False):
If True, the output file is compressed using gzip. Especially useful for large sequences and granular
attributions with umerged hidden dimensions.
ndarray_compact (:obj:`bool`, *optional*, defaults to True):
If True, the arrays for scores and attributions are stored in a compact b64 format. Otherwise, they are
stored as plain lists of floats.
use_primitives (:obj:`bool`, *optional*, defaults to False):
If True, the output is stored as a list of dictionaries with primitive types (e.g. int, float, str).
Note that an attribution saved with this option cannot be loaded with the `load` method.
split_sequences (:obj:`bool`, *optional*, defaults to False):
If True, the output is split into multiple files, one per sequence. The file names are generated by
appending the sequence index to the given path (e.g. ``./out.json`` with two sequences ->
``./out_0.json``, ``./out_1.json``)
"""
if not overwrite and Path(path).exists():
raise ValueError(f"{path} already exists. Override with overwrite=True.")
save_outs = []
paths = []
if split_sequences:
for i, seq in enumerate(self.sequence_attributions):
attr_out = deepcopy(self)
attr_out.sequence_attributions = [seq]
attr_out.step_attributions = None
attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
save_outs.append(attr_out)
paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
else:
save_outs.append(self)
paths.append(path)
for attr_out, path_out in zip(save_outs, paths):
with open(path_out, f"w{'b' if compress else ''}") as f:
json_advanced_dump(
attr_out,
f,
allow_nan=True,
indent=4,
sort_keys=True,
ndarray_compact=ndarray_compact,
compression=compress,
use_primitives=use_primitives,
)
[docs] @staticmethod
def load(
path: PathLike,
decompress: bool = False,
) -> "FeatureAttributionOutput":
"""Load saved attribution output into a new :class:`~inseq.data.FeatureAttributionOutput` object.
Args:
path (:obj:`str`): Path to the JSON file containing the saved attribution output.
Note that the file must have been saved with the :meth:`~inseq.data.FeatureAttributionOutput.save`
method with ``use_primitives=False`` in order to be loaded correctly.
decompress (:obj:`bool`, *optional*, defaults to False):
If True, the input file is decompressed using gzip.
Returns:
:class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
"""
out = json_advanced_load(path, decompression=decompress)
out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
if out.step_attributions is not None:
out.step_attributions = [step.torch() for step in out.step_attributions]
return out
[docs] def aggregate(
self,
aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None,
**kwargs,
) -> "FeatureAttributionOutput":
"""Aggregate the sequence attributions using one or more aggregators.
Args:
aggregator (:obj:`AggregatorPipeline` or :obj:`Type[Aggregator]`, optional): Aggregator
or pipeline to use. If not provided, the default aggregator for every sequence attribution
is used.
Returns:
:class:`~inseq.data.FeatureAttributionOutput`: Aggregated attribution output
"""
aggregated = deepcopy(self)
for idx, seq in enumerate(aggregated.sequence_attributions):
aggregated.sequence_attributions[idx] = seq.aggregate(aggregator, **kwargs)
return aggregated
[docs] def show(
self,
min_val: Optional[int] = None,
max_val: Optional[int] = None,
display: bool = True,
return_html: Optional[bool] = False,
aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None,
**kwargs,
) -> Optional[str]:
"""Visualize the sequence attributions.
Args:
min_val (int, optional): Minimum value for color scale.
max_val (int, optional): Maximum value for color scale.
display (bool, optional): If True, display the attribution visualization.
return_html (bool, optional): If True, return the attribution visualization as HTML.
aggregator (:obj:`AggregatorPipeline` or :obj:`Type[Aggregator]`, optional): Aggregator
or pipeline to use. If not provided, the default aggregator for every sequence attribution
is used.
Returns:
str: Attribution visualization as HTML if `return_html=True`, None otherwise.
"""
out_str = ""
for attr in self.sequence_attributions:
if return_html:
out_str += attr.show(min_val, max_val, display, return_html, aggregator, **kwargs)
else:
attr.show(min_val, max_val, display, return_html, aggregator, **kwargs)
if return_html:
return out_str
[docs] @classmethod
def merge_attributions(cls, attributions: List["FeatureAttributionOutput"]) -> "FeatureAttributionOutput":
"""Merges multiple :class:`~inseq.data.FeatureAttributionOutput` objects into a single one.
Merging is allowed only if the two outputs match on the fields specified in ``_merge_match_info_fields``.
Args:
attributions (`list(FeatureAttributionOutput)`): The FeatureAttributionOutput objects to be merged.
Returns:
`FeatureAttributionOutput`: Merged object
"""
assert all(
isinstance(x, FeatureAttributionOutput) for x in attributions
), "Only FeatureAttributionOutput objects can be merged."
first = attributions[0]
for match_field in cls._merge_match_info_fields:
assert all(
attr.info[match_field] == first.info[match_field]
if match_field in first.info
else match_field not in attr.info
for attr in attributions
), f"Cannot merge: incompatible values for field {match_field}"
out_info = first.info.copy()
if "attr_pos_end" in first.info:
out_info.update({"attr_pos_end": max(attr.info["attr_pos_end"] for attr in attributions)})
if "generated_texts" in first.info:
out_info.update(
{"generated_texts": [text for attr in attributions for text in attr.info["generated_texts"]]}
)
if "input_texts" in first.info:
out_info.update({"input_texts": [text for attr in attributions for text in attr.info["input_texts"]]})
return cls(
sequence_attributions=[seqattr for attr in attributions for seqattr in attr.sequence_attributions],
step_attributions=[stepattr for attr in attributions for stepattr in attr.step_attributions]
if first.step_attributions is not None
else None,
info=out_info,
)
def weight_attributions(self, step_score_id: str):
for i, attr in enumerate(self.sequence_attributions):
self.sequence_attributions[i] = attr.weight_attributions(step_score_id)
[docs] def get_scores_dicts(
self, aggregator: Union[AggregatorPipeline, Type[Aggregator]] = None, do_aggregation: bool = True, **kwargs
) -> List[Dict[str, Dict[str, Dict[str, float]]]]:
"""Get all computed scores (attributions and step scores) for all sequences as a list of dictionaries.
Returns:
:obj:`list(dict)`: List containing one dictionary per sequence. Every dictionary contains the keys
"source_attributions", "target_attributions" and "step_scores". For each of these keys, the value is a
dictionary with generated tokens as keys, and for values a final dictionary. For "step_scores", the keys
of the final dictionary are the step score ids, and the values are the scores.
For "source_attributions" and "target_attributions", the keys of the final dictionary are respectively
source and target tokens, and the values are the attribution scores.
This output is intended to be easily converted to a pandas DataFrame. The following example produces a list of
DataFrames, one for each sequence, matching the source attributions that would be visualized by out.show().
```python
dfs = [pd.DataFrame(x["source_attributions"]) for x in out.get_scores_dicts()]
```
"""
return [attr.get_scores_dicts(aggregator, do_aggregation, **kwargs) for attr in self.sequence_attributions]
# Gradient attribution classes
[docs]@dataclass(eq=False, repr=False)
class GradientFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
"""Raw output of a single sequence of gradient feature attribution.
Adds the convergence delta and default L2 + normalization merging of attributions to the base class.
"""
def __post_init__(self):
super().__post_init__()
self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions
self._dict_aggregate_fn["target_attributions"]["sequence_aggregate"] = sum_normalize_attributions
if "deltas" not in self._dict_aggregate_fn["step_scores"]["span_aggregate"]:
self._dict_aggregate_fn["step_scores"]["span_aggregate"]["deltas"] = abs_max
[docs]@dataclass(eq=False, repr=False)
class GradientFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""Raw output of a single step of gradient feature attribution.
Adds the convergence delta to the base class.
"""
_sequence_cls: Type["FeatureAttributionSequenceOutput"] = GradientFeatureAttributionSequenceOutput
# Perturbation attribution classes
@dataclass(eq=False, repr=False)
class OcclusionFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
"""Raw output of a single sequence of occlusion feature attribution."""
def __post_init__(self):
super().__post_init__()
self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = normalize_attributions
self._dict_aggregate_fn["target_attributions"]["sequence_aggregate"] = normalize_attributions
@dataclass(eq=False, repr=False)
class OcclusionFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""Raw output of a single step of occlusion feature attribution."""
_sequence_cls: Type["FeatureAttributionSequenceOutput"] = OcclusionFeatureAttributionSequenceOutput
@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
"""Raw output of a single sequence of perturbation feature attribution."""
def __post_init__(self):
super().__post_init__()
self._dict_aggregate_fn["source_attributions"]["sequence_aggregate"] = sum_normalize_attributions
self._dict_aggregate_fn["target_attributions"]["sequence_aggregate"] = sum_normalize_attributions
@dataclass(eq=False, repr=False)
class PerturbationFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""Raw output of a single step of perturbation feature attribution."""
_sequence_cls: Type["FeatureAttributionSequenceOutput"] = PerturbationFeatureAttributionSequenceOutput