# Copyright 2021 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Attention-based feature attribution methods. """
import logging
from typing import Any, Callable, Dict, Union
from ...data import Batch, EncoderDecoderBatch, FeatureAttributionStepOutput
from ...utils import Registry, pretty_tensor
from ...utils.typing import SingleScorePerStepTensor, TargetIdsTensor
from ..attribution_decorators import set_hook, unset_hook
from ..step_functions import STEP_SCORES_MAP
from .attribution_utils import get_source_target_attributions
from .feature_attribution import FeatureAttribution
from .ops import AttentionWeights
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class InternalsAttributionRegistry(FeatureAttribution, Registry):
    r"""Model Internals-based attribution method registry."""

    @set_hook
    def hook(self, **kwargs):
        """Intentionally empty; any behavior comes from the ``set_hook`` decorator."""
        pass

    @unset_hook
    def unhook(self, **kwargs):
        """Intentionally empty; any behavior comes from the ``unset_hook`` decorator."""
        pass

    def attribute_step(
        self,
        attribute_fn_main_args: Dict[str, Any],
        attribution_args: Union[Dict[str, Any], None] = None,
    ) -> FeatureAttributionStepOutput:
        r"""
        Performs a single attribution step for the specified attribution arguments.

        Args:
            attribute_fn_main_args (:obj:`dict`): Main arguments used for the attribution method. These are built
                from model inputs at the current step of the feature attribution process.
            attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution method.
                These can be specified by the user while calling the top level `attribute` methods.
                Defaults to None, which is treated as an empty dictionary.

        Returns:
            :class:`~inseq.data.FeatureAttributionStepOutput`: A dataclass containing a tensor of source
            attributions of size `(batch_size, source_length)` and possibly a tensor of target attributions of
            size `(batch_size, prefix length)` if attribute_target=True. The step scores are returned empty by
            this method.
        """
        # Build a fresh dict on every call instead of sharing one mutable default object.
        if attribution_args is None:
            attribution_args = {}
        # Both dictionaries are forwarded as keyword arguments to the wrapped method.
        attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
        # The helper is told whether the model is encoder-decoder so it can split the raw output.
        source_attributions, target_attributions = get_source_target_attributions(
            attr, self.attribution_model.is_encoder_decoder
        )
        return FeatureAttributionStepOutput(
            source_attributions=source_attributions,
            target_attributions=target_attributions,
            step_scores={},
        )
class AttentionWeightsAttribution(InternalsAttributionRegistry):
    """
    Basic attention attribution method: the attribution scores are the attention weights retrieved from the model.

    Attribute Args:
        aggregate_heads_fn (:obj:`str` or :obj:`callable`): How to aggregate across attention heads.
            Either one of `average` (default if heads is list, tuple or None), `max`, `min` or `single` (default
            if heads is int), or a user-defined function.
        aggregate_layers_fn (:obj:`str` or :obj:`callable`): How to aggregate across layers.
            Either one of `average` (default if layers is tuple or list), `max`, `min` or `single` (default if
            layers is int or None), or a user-defined function.
        heads (:obj:`int` or :obj:`tuple[int, int]` or :obj:`list(int)`, optional): A single value selects the
            head at that index. A tuple of two indices selects all heads between the indices, which are then
            aggregated using ``aggregate_heads_fn``. A list of indices selects exactly those heads for
            aggregation. A head must be specified when ``aggregate_heads_fn`` is "single". When omitted, all
            heads are passed to ``aggregate_heads_fn``.
        layers (:obj:`int` or :obj:`tuple[int, int]` or :obj:`list(int)`, optional): A single value selects the
            layer at that index. A tuple of two indices selects all layers among the indices, which are then
            aggregated using ``aggregate_layers_fn``. A list of indices selects exactly those layers for
            aggregation. When ``aggregate_layers_fn`` is "single", the last layer is used by default. When
            omitted, all available layers are passed to ``aggregate_layers_fn``.

    Example:
        - ``model.attribute(src)`` will return the average attention for all heads of the last layer.
        - ``model.attribute(src, heads=0)`` will return the attention weights for the first head of the last
          layer.
        - ``model.attribute(src, heads=(0, 5), aggregate_heads_fn="max", layers=[0, 2, 7])`` will return the
          maximum attention weights for the first 5 heads averaged across the first, third, and eighth layers.
    """

    # Identifier of this attribution method.
    method_name = "attention"

    def __init__(self, attribution_model, **kwargs):
        """Initialize the parent with the model, then attach the attention-weights operation.

        Extra keyword arguments are accepted but not used by this constructor.
        """
        super().__init__(attribution_model)
        attention_op = AttentionWeights(attribution_model)
        self.method = attention_op