# Source code for inseq.models.decoder_only

import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from ..attr.feat import join_token_ids
from ..data import (
    Batch,
    BatchEmbedding,
    BatchEncoding,
    DecoderOnlyBatch,
    FeatureAttributionInput,
    FeatureAttributionStepOutput,
)
from ..utils import pretty_tensor
from ..utils.typing import (
    AttributionForwardInputs,
    EmbeddingsTensor,
    ExpandedTargetIdsTensor,
    FullLogitsTensor,
    IdsTensor,
    OneOrMoreTokenSequences,
    SingleScorePerStepTensor,
    TargetIdsTensor,
    TextSequences,
    TokenWithId,
)
from .attribution_model import AttributionModel, ModelOutput

# Module-level logger, named after this module so output can be filtered per-module.
logger: logging.Logger = logging.getLogger(name=__name__)


class DecoderOnlyAttributionModel(AttributionModel):
    """AttributionModel class for attributing decoder-only models.

    In this class the "source" side of the generic attribution interface is
    always empty (``None``): the whole input sequence is treated as the target
    prefix, and all encoder-related step-function arguments are set to ``None``.
    """

    def prepare_inputs_for_attribution(
        self,
        inputs: FeatureAttributionInput,
        include_eos_baseline: bool = False,
    ) -> DecoderOnlyBatch:
        """Convert user inputs into a batch ready for attribution.

        Args:
            inputs: One of:
                - a :class:`~inseq.data.Batch`, returned unchanged;
                - a string or a list of strings, encoded with :meth:`encode`
                  (baselines included);
                - a :class:`~inseq.data.BatchEncoding`, used as-is.
            include_eos_baseline: Forwarded to :meth:`encode` when ``inputs``
                is raw text; ignored otherwise.

        Returns:
            A :class:`~inseq.data.DecoderOnlyBatch` holding the encodings
            together with input and baseline embeddings. NOTE(review): a
            ``Batch`` input is returned as-is and is not converted, so its
            concrete type is whatever the caller passed.

        Raises:
            ValueError: If ``inputs`` is of an unsupported type.
        """
        if isinstance(inputs, Batch):
            return inputs
        if isinstance(inputs, (str, list)):
            # Decoder-only models do not tokenize as targets,
            # since a single tokenizer is available.
            encodings: BatchEncoding = self.encode(
                inputs,
                return_baseline=True,
                include_eos_baseline=include_eos_baseline,
            )
        elif isinstance(inputs, BatchEncoding):
            encodings = inputs
        else:
            raise ValueError(
                "inputs must be either a string, a list of strings, a BatchEncoding or a Batch, "
                f"not {type(inputs)}"
            )
        baseline_embeds = self.embed(encodings.baseline_ids)
        embeddings = BatchEmbedding(
            input_embeds=self.embed(encodings.input_ids),
            baseline_embeds=baseline_embeds,
        )
        return DecoderOnlyBatch(encodings, embeddings)

    @staticmethod
    def format_forward_args(
        inputs: DecoderOnlyBatch,
        use_embeddings: bool = True,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for :meth:`get_forward_output`.

        Args:
            inputs: The batch to run through the model.
            use_embeddings: If True, pass the input embeddings as
                ``forward_tensor``; otherwise pass the input ids.

        Returns:
            A dict with ``forward_tensor`` and ``attention_mask`` keys.
        """
        return {
            "forward_tensor": inputs.input_embeds if use_embeddings else inputs.input_ids,
            "attention_mask": inputs.attention_mask,
        }

    @staticmethod
    def format_attribution_args(
        batch: DecoderOnlyBatch,
        target_ids: TargetIdsTensor,
        attributed_fn: Callable[..., SingleScorePerStepTensor],
        attributed_fn_args: Optional[Dict[str, Any]] = None,
        attribute_batch_ids: bool = False,
        forward_batch_embeds: bool = True,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Tuple[Union[IdsTensor, EmbeddingsTensor, None], ...]]:
        """Build the arguments passed to the attribution method.

        Args:
            batch: The batch being attributed.
            target_ids: Ids of the tokens attributed at the current step.
            attributed_fn: Step function whose output is attributed.
            attributed_fn_args: Extra keyword arguments for ``attributed_fn``.
                They are flattened into positional extra forward arguments
                (names first, then values) and re-zipped in :meth:`forward`.
                ``None`` (the default) means no extra arguments.
            attribute_batch_ids: If True, attribute w.r.t. input ids and use
                baseline ids; otherwise use input and baseline embeddings.
            forward_batch_embeds: Forwarded to :meth:`forward` as its
                ``use_embeddings`` argument.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``(attribute_fn_args, baselines)`` pair. ``attribute_fn_args``
            has ``inputs`` and ``additional_forward_args`` keys.
        """
        # Use None rather than a shared mutable ``{}`` as the default.
        if attributed_fn_args is None:
            attributed_fn_args = {}
        if attribute_batch_ids:
            inputs = (batch.input_ids,)
            baselines = (batch.baseline_ids,)
        else:
            inputs = (batch.input_embeds,)
            baselines = (batch.baseline_embeds,)
        # The order below must match the positional parameters of ``forward``
        # following ``forward_tensor``: input_ids, target_ids, attributed_fn,
        # attention_mask, use_embeddings, attributed_fn_argnames, *args.
        additional_forward_args = (
            # Ids are always explicitly passed as extra arguments to enable
            # usage in custom attribution functions.
            batch.input_ids,
            # Making targets 2D enables _expand_additional_forward_args
            # in Captum to preserve the expected batch dimension for methods
            # such as integrated gradients.
            target_ids.unsqueeze(-1),
            attributed_fn,
            batch.attention_mask,
            # Defines how to treat source and target tensors
            # Maps on the use_embeddings argument of forward
            forward_batch_embeds,
            list(attributed_fn_args.keys()),
        ) + tuple(attributed_fn_args.values())
        attribute_fn_args = {
            "inputs": inputs,
            "additional_forward_args": additional_forward_args,
        }
        return attribute_fn_args, baselines

    def get_text_sequences(self, batch: DecoderOnlyBatch) -> TextSequences:
        """Return the batch text as target sequences; sources are always ``None``."""
        return TextSequences(
            sources=None,
            targets=self.convert_tokens_to_string(batch.input_tokens, as_targets=True),
        )

    @staticmethod
    def enrich_step_output(
        step_output: FeatureAttributionStepOutput,
        batch: DecoderOnlyBatch,
        target_tokens: OneOrMoreTokenSequences,
        target_ids: TargetIdsTensor,
    ) -> FeatureAttributionStepOutput:
        r"""
        Enriches the attribution output with token information, producing the finished
        :class:`~inseq.data.FeatureAttributionStepOutput` object.

        Args:
            step_output (:class:`~inseq.data.FeatureAttributionStepOutput`): The output produced
                by the attribution step, with missing batch information.
            batch (:class:`~inseq.data.DecoderOnlyBatch`): The batch on which attribution was performed.
            target_tokens (:obj:`OneOrMoreTokenSequences`): Tokens attributed at this step, aligned
                with ``target_ids``; the first element of each entry is used as the token string.
            target_ids (:obj:`torch.Tensor`): Target token ids of size `(batch_size, 1)` corresponding to tokens
                for which the attribution step was performed.

        Returns:
            :class:`~inseq.data.FeatureAttributionStepOutput`: The enriched attribution output. It is the same
            ``step_output`` object, modified in place: ``source`` is set to ``None``, and ``target`` and
            ``prefix`` are filled with token/id pairs.
        """
        # A 0-d tensor (single id) is promoted to 1-d so ``tolist`` yields a list.
        if len(target_ids.shape) == 0:
            target_ids = target_ids.unsqueeze(0)
        step_output.source = None
        step_output.target = [[TokenWithId(token[0], id)] for token, id in zip(target_tokens, target_ids.tolist())]
        step_output.prefix = join_token_ids(batch.target_tokens, batch.input_ids.tolist())
        return step_output

    def format_step_function_args(
        self,
        forward_output: ModelOutput,
        target_ids: ExpandedTargetIdsTensor,
        decoder_input_ids: Optional[IdsTensor] = None,
        decoder_input_embeds: Optional[EmbeddingsTensor] = None,
        decoder_attention_mask: Optional[IdsTensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for the attributed step function.

        All encoder-side entries are ``None`` for decoder-only models. Entries
        in ``kwargs`` take precedence over the named entries built here.
        """
        return {
            "attribution_model": self,
            "forward_output": forward_output,
            "encoder_input_ids": None,
            "decoder_input_ids": decoder_input_ids,
            "encoder_input_embeds": None,
            "decoder_input_embeds": decoder_input_embeds,
            "target_ids": target_ids,
            "encoder_attention_mask": None,
            "decoder_attention_mask": decoder_attention_mask,
            **kwargs,
        }

    def get_forward_output(
        self,
        forward_tensor: AttributionForwardInputs,
        attention_mask: Optional[IdsTensor] = None,
        use_embeddings: bool = True,
        **kwargs,
    ) -> ModelOutput:
        """Run the underlying model on either embeddings or ids.

        Args:
            forward_tensor: Input embeddings if ``use_embeddings`` is True,
                input ids otherwise.
            attention_mask: Optional attention mask passed to the model.
            use_embeddings: Selects whether ``forward_tensor`` is passed as
                ``inputs_embeds`` or as ``input_ids``.
            **kwargs: Forwarded to ``self.model``.
        """
        embeds = forward_tensor if use_embeddings else None
        ids = None if use_embeddings else forward_tensor
        return self.model(
            input_ids=ids,
            inputs_embeds=embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    def forward(
        self,
        forward_tensor: AttributionForwardInputs,
        input_ids: IdsTensor,
        target_ids: ExpandedTargetIdsTensor,
        attributed_fn: Callable[..., SingleScorePerStepTensor],
        attention_mask: Optional[IdsTensor] = None,
        use_embeddings: bool = True,
        attributed_fn_argnames: Optional[List[str]] = None,
        *args,
    ) -> SingleScorePerStepTensor:
        """Run the model and return the attributed step function's output.

        The positional parameters after ``forward_tensor`` mirror the
        ``additional_forward_args`` built by :meth:`format_attribution_args`.

        Args:
            forward_tensor: Input embeddings or ids (see ``use_embeddings``).
            input_ids: Input ids, always passed so that step functions can use them.
            target_ids: Target ids with a trailing singleton dimension,
                which is squeezed here.
            attributed_fn: Step function whose output is returned.
            attention_mask: Optional attention mask for the model.
            use_embeddings: Whether ``forward_tensor`` holds embeddings.
            attributed_fn_argnames: Names for the values in ``*args``. ``None``
                is treated as no extra arguments.
            *args: Values of the extra ``attributed_fn`` arguments. Values
                that are ``None`` are not forwarded.

        Raises:
            ValueError: If ``args`` and ``attributed_fn_argnames`` differ in length.
        """
        if attributed_fn_argnames is None:
            attributed_fn_argnames = []
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(args) != len(attributed_fn_argnames):
            raise ValueError("Number of arguments and number of argnames must match")
        target_ids = target_ids.squeeze(-1)
        output = self.get_forward_output(
            forward_tensor=forward_tensor,
            attention_mask=attention_mask,
            use_embeddings=use_embeddings,
        )
        # Avoid formatting the full logits tensor on every call unless DEBUG is on.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"logits: {pretty_tensor(output.logits)}")
        step_function_args = self.format_step_function_args(
            attribution_model=self,
            forward_output=output,
            decoder_input_ids=input_ids,
            decoder_input_embeds=forward_tensor if use_embeddings else None,
            target_ids=target_ids,
            decoder_attention_mask=attention_mask,
            **{k: v for k, v in zip(attributed_fn_argnames, args) if v is not None},
        )
        return attributed_fn(**step_function_args)