import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from ..attr.feat import join_token_ids
from ..data import (
Batch,
BatchEmbedding,
BatchEncoding,
EncoderDecoderBatch,
FeatureAttributionInput,
FeatureAttributionStepOutput,
)
from ..utils import pretty_tensor
from ..utils.typing import (
AttributionForwardInputs,
EmbeddingsTensor,
ExpandedTargetIdsTensor,
FullLogitsTensor,
IdsTensor,
OneOrMoreTokenSequences,
SingleScorePerStepTensor,
TargetIdsTensor,
TextSequences,
TokenWithId,
)
from .attribution_model import AttributionModel, ModelOutput
# Module-level logger named after this module; used below for debug output of model logits.
logger: logging.Logger = logging.getLogger(__name__)
class EncoderDecoderAttributionModel(AttributionModel):
    """AttributionModel class for attributing encoder-decoder models."""

    @staticmethod
    def format_forward_args(
        inputs: EncoderDecoderBatch,
        use_embeddings: bool = True,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for a forward pass over an encoder-decoder batch.

        The returned keys match the parameters of :meth:`get_forward_output`.

        Args:
            inputs (:class:`~inseq.data.EncoderDecoderBatch`): Batch holding source and target inputs.
            use_embeddings (:obj:`bool`, optional): If ``True`` the source side is passed as input
                embeddings, otherwise as input ids. Defaults to ``True``.

        Returns:
            :obj:`dict`: ``forward_tensor``, ``decoder_input_embeds``, ``encoder_attention_mask`` and
            ``decoder_attention_mask``.
        """
        return {
            "forward_tensor": inputs.sources.input_embeds if use_embeddings else inputs.sources.input_ids,
            # The decoder side is always passed as embeddings, regardless of ``use_embeddings``.
            "decoder_input_embeds": inputs.targets.input_embeds,
            "encoder_attention_mask": inputs.sources.attention_mask,
            "decoder_attention_mask": inputs.targets.attention_mask,
        }

    @staticmethod
    def format_attribution_args(
        batch: EncoderDecoderBatch,
        target_ids: TargetIdsTensor,
        attributed_fn: Callable[..., SingleScorePerStepTensor],
        attribute_target: bool = False,
        attributed_fn_args: Optional[Dict[str, Any]] = None,
        attribute_batch_ids: bool = False,
        forward_batch_embeds: bool = True,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Tuple[Union[IdsTensor, EmbeddingsTensor, None], ...]]:
        """Build the inputs, extra forward arguments and baselines for an attribution call.

        The positional order of ``additional_forward_args`` mirrors the parameters of :meth:`forward`
        that follow the attributed ``inputs``. When the target is not attributed, the target embeddings
        are prepended so that they fill the ``decoder_input_embeds`` slot of :meth:`forward`.

        Args:
            batch (:class:`~inseq.data.EncoderDecoderBatch`): The batch being attributed.
            target_ids (:obj:`torch.Tensor`): Ids of the tokens whose prediction is attributed.
            attributed_fn (:obj:`Callable`): Step function producing the score to attribute.
            attribute_target (:obj:`bool`, optional): Whether the target prefix is attributed in
                addition to the source. Defaults to ``False``.
            attributed_fn_args (:obj:`dict`, optional): Extra keyword arguments for ``attributed_fn``.
                Their names and values are appended to the forward arguments. Defaults to no extra args.
            attribute_batch_ids (:obj:`bool`, optional): If ``True`` attribute token ids and use id
                baselines, otherwise embeddings and embedding baselines. Defaults to ``False``.
            forward_batch_embeds (:obj:`bool`, optional): Value forwarded to the ``use_embeddings``
                parameter of :meth:`forward`. Defaults to ``True``.
            **kwargs: Ignored.

        Returns:
            :obj:`tuple`: A dict with ``inputs`` and ``additional_forward_args`` entries, and the tuple
            of baselines aligned with ``inputs``.
        """
        # ``None`` instead of a shared mutable ``{}`` default.
        if attributed_fn_args is None:
            attributed_fn_args = {}
        if attribute_batch_ids:
            inputs = (batch.sources.input_ids,)
            baselines = (batch.sources.baseline_ids,)
        else:
            inputs = (batch.sources.input_embeds,)
            baselines = (batch.sources.baseline_embeds,)
        if attribute_target:
            # NOTE(review): with attribute_batch_ids=True the target *ids* end up in the
            # ``decoder_input_embeds`` slot of ``forward``, which passes them on as decoder
            # input embeddings -- confirm this combination is only used where that is handled.
            if attribute_batch_ids:
                inputs += (batch.targets.input_ids,)
                baselines += (batch.targets.baseline_ids,)
            else:
                inputs += (batch.targets.input_embeds,)
                baselines += (batch.targets.baseline_embeds,)
        additional_forward_args = (
            # Ids are always explicitly passed as extra arguments to enable
            # usage in custom attribution functions.
            batch.sources.input_ids,
            batch.targets.input_ids,
            # Making targets 2D enables _expand_additional_forward_args
            # in Captum to preserve the expected batch dimension for methods
            # such as integrated gradients.
            target_ids.unsqueeze(-1),
            attributed_fn,
            batch.sources.attention_mask,
            batch.targets.attention_mask,
            # Defines how to treat source and target tensors
            # Maps on the use_embeddings argument of forward
            forward_batch_embeds,
            list(attributed_fn_args.keys()),
        ) + tuple(attributed_fn_args.values())
        if not attribute_target:
            # Target embeddings are not attributed, so they travel as a fixed extra argument.
            additional_forward_args = (batch.targets.input_embeds,) + additional_forward_args
        attribute_fn_args = {
            "inputs": inputs,
            "additional_forward_args": additional_forward_args,
        }
        return attribute_fn_args, baselines

    def get_text_sequences(self, batch: EncoderDecoderBatch) -> TextSequences:
        """Convert the source and target tokens of ``batch`` into strings.

        Args:
            batch (:class:`~inseq.data.EncoderDecoderBatch`): The batch to detokenize.

        Returns:
            :class:`TextSequences`: Source strings and target strings (the latter converted with
            ``as_targets=True``).
        """
        return TextSequences(
            sources=self.convert_tokens_to_string(batch.sources.input_tokens),
            targets=self.convert_tokens_to_string(batch.targets.input_tokens, as_targets=True),
        )

    @staticmethod
    def enrich_step_output(
        step_output: FeatureAttributionStepOutput,
        batch: EncoderDecoderBatch,
        target_tokens: OneOrMoreTokenSequences,
        target_ids: TargetIdsTensor,
    ) -> FeatureAttributionStepOutput:
        r"""
        Enriches the attribution output with token information, producing the finished
        :class:`~inseq.data.FeatureAttributionStepOutput` object.

        Args:
            step_output (:class:`~inseq.data.FeatureAttributionStepOutput`): The output produced
                by the attribution step, with missing batch information. It is updated in place
                and also returned.
            batch (:class:`~inseq.data.EncoderDecoderBatch`): The batch on which attribution was performed.
            target_tokens (:obj:`OneOrMoreTokenSequences`): One token sequence per batch element; only
                the first token of each sequence is used as the attributed target token.
            target_ids (:obj:`torch.Tensor`): Target token ids, one per batch element (size
                ``(batch_size,)``), corresponding to tokens for which the attribution step was performed.
                A 0-dimensional tensor is promoted to size ``(1,)``.

        Returns:
            :class:`~inseq.data.FeatureAttributionStepOutput`: The enriched attribution output.
        """
        # A single example may come as a scalar tensor; make it iterable like a batch.
        if target_ids.ndim == 0:
            target_ids = target_ids.unsqueeze(0)
        step_output.source = join_token_ids(batch.sources.input_tokens, batch.sources.input_ids.tolist())
        step_output.target = [
            [TokenWithId(tokens[0], token_id)] for tokens, token_id in zip(target_tokens, target_ids.tolist())
        ]
        step_output.prefix = join_token_ids(batch.targets.input_tokens, batch.targets.input_ids.tolist())
        return step_output

    def format_step_function_args(
        self,
        forward_output: ModelOutput,
        target_ids: ExpandedTargetIdsTensor,
        encoder_input_ids: Optional[IdsTensor] = None,
        decoder_input_ids: Optional[IdsTensor] = None,
        encoder_input_embeds: Optional[EmbeddingsTensor] = None,
        decoder_input_embeds: Optional[EmbeddingsTensor] = None,
        encoder_attention_mask: Optional[IdsTensor] = None,
        decoder_attention_mask: Optional[IdsTensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Collect the keyword arguments passed to an attributed step function.

        Extra ``kwargs`` are included first, so the named entries below (including
        ``attribution_model``, always set to ``self``) take precedence over them.

        Returns:
            :obj:`dict`: Keyword arguments for the step function.
        """
        return {
            **kwargs,
            "attribution_model": self,
            "forward_output": forward_output,
            "encoder_input_ids": encoder_input_ids,
            "decoder_input_ids": decoder_input_ids,
            "encoder_input_embeds": encoder_input_embeds,
            "decoder_input_embeds": decoder_input_embeds,
            "target_ids": target_ids,
            "encoder_attention_mask": encoder_attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
        }

    def get_forward_output(
        self,
        forward_tensor: AttributionForwardInputs,
        encoder_attention_mask: Optional[IdsTensor] = None,
        decoder_input_embeds: Optional[EmbeddingsTensor] = None,
        decoder_attention_mask: Optional[IdsTensor] = None,
        use_embeddings: bool = True,
        **kwargs,
    ) -> ModelOutput:
        """Run ``self.model`` on an encoder-decoder input.

        Args:
            forward_tensor: Encoder input, interpreted as embeddings when ``use_embeddings`` is ``True``
                and as token ids otherwise.
            encoder_attention_mask: Attention mask for the encoder input.
            decoder_input_embeds: Embeddings fed to the decoder.
            decoder_attention_mask: Attention mask for the decoder input.
            use_embeddings (:obj:`bool`, optional): How to interpret ``forward_tensor``. Defaults to ``True``.
            **kwargs: Passed through unchanged to ``self.model``.

        Returns:
            The output of ``self.model``.
        """
        encoder_embeds = forward_tensor if use_embeddings else None
        encoder_ids = None if use_embeddings else forward_tensor
        return self.model(
            input_ids=encoder_ids,
            inputs_embeds=encoder_embeds,
            attention_mask=encoder_attention_mask,
            decoder_inputs_embeds=decoder_input_embeds,
            decoder_attention_mask=decoder_attention_mask,
            **kwargs,
        )

    def forward(
        self,
        encoder_tensors: AttributionForwardInputs,
        decoder_input_embeds: AttributionForwardInputs,
        encoder_input_ids: IdsTensor,
        decoder_input_ids: IdsTensor,
        target_ids: ExpandedTargetIdsTensor,
        attributed_fn: Callable[..., SingleScorePerStepTensor],
        encoder_attention_mask: Optional[IdsTensor] = None,
        decoder_attention_mask: Optional[IdsTensor] = None,
        use_embeddings: bool = True,
        attributed_fn_argnames: Optional[List[str]] = None,
        *args,
    ) -> SingleScorePerStepTensor:
        """Run the model and score the step with ``attributed_fn``.

        The arguments after the attributed inputs arrive positionally in the order built by
        :meth:`format_attribution_args`.

        Args:
            encoder_tensors: Source embeddings, or source ids when ``use_embeddings`` is ``False``.
            decoder_input_embeds: Embeddings fed to the decoder.
            encoder_input_ids: Source ids, forwarded to ``attributed_fn``.
            decoder_input_ids: Target prefix ids, forwarded to ``attributed_fn``.
            target_ids: Target ids with a trailing singleton dimension, squeezed before use.
            attributed_fn: Step function whose output is returned.
            encoder_attention_mask: Attention mask for the encoder input.
            decoder_attention_mask: Attention mask for the decoder input.
            use_embeddings (:obj:`bool`, optional): How to interpret ``encoder_tensors``. Defaults to ``True``.
            attributed_fn_argnames (:obj:`list`, optional): Names of the extra ``*args`` values for
                ``attributed_fn``. ``None`` means there are no extra arguments.
            *args: Values matching ``attributed_fn_argnames``; ``None`` values are not forwarded.

        Returns:
            The output of ``attributed_fn``.
        """
        # ``None`` previously crashed on ``len(None)``; treat it as "no extra arguments".
        if attributed_fn_argnames is None:
            attributed_fn_argnames = []
        assert len(args) == len(attributed_fn_argnames), "Number of arguments and number of argnames must match"
        # Undo the ``unsqueeze(-1)`` applied in ``format_attribution_args``.
        target_ids = target_ids.squeeze(-1)
        output = self.get_forward_output(
            forward_tensor=encoder_tensors,
            encoder_attention_mask=encoder_attention_mask,
            decoder_input_embeds=decoder_input_embeds,
            decoder_attention_mask=decoder_attention_mask,
            use_embeddings=use_embeddings,
        )
        # Formatting the logits is not free and this runs on every forward call:
        # only do it when debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("logits: %s", pretty_tensor(output.logits))
        extra_fn_args = {name: value for name, value in zip(attributed_fn_argnames, args) if value is not None}
        step_function_args = self.format_step_function_args(
            attribution_model=self,
            forward_output=output,
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            encoder_input_embeds=encoder_tensors if use_embeddings else None,
            decoder_input_embeds=decoder_input_embeds,
            target_ids=target_ids,
            encoder_attention_mask=encoder_attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            **extra_fn_args,
        )
        return attributed_fn(**step_function_args)