"""HuggingFace Seq2seq model."""
import logging
from abc import abstractmethod
from typing import Any, NoReturn
import torch
from torch import long
from transformers.modeling_outputs import CausalLMOutput, Seq2SeqLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput
from ..attr.attribution_decorators import batched
from ..data import BatchEncoding
from ..utils import check_device
from ..utils.typing import (
EmbeddingsTensor,
IdsTensor,
LogitsTensor,
MultiLayerEmbeddingsTensor,
MultiLayerMultiUnitScoreTensor,
OneOrMoreIdSequences,
OneOrMoreTokenSequences,
OneOrMoreTokenWithIdSequences,
TextInput,
TokenWithId,
VocabularyEmbeddingsTensor,
)
from .attribution_model import AttributionModel
from .decoder_only import DecoderOnlyAttributionModel
from .encoder_decoder import EncoderDecoderAttributionModel
from .model_decorators import unhooked
logger = logging.getLogger(__name__)
logging.getLogger("urllib3").setLevel(logging.WARNING)
# Update if other model types are added
SUPPORTED_AUTOCLASSES = [AutoModelForSeq2SeqLM, AutoModelForCausalLM]
[docs]
class HuggingfaceModel(AttributionModel):
"""Model wrapper for any ForCausalLM and ForConditionalGeneration model on the HuggingFace Hub used to enable
feature attribution. Corresponds to AutoModelForCausalLM and AutoModelForSeq2SeqLM auto classes.
Attributes:
_autoclass (:obj:`Type[transformers.AutoModel`]): The HuggingFace model class to use for initialization.
Must be defined in subclasses.
model (:obj:`transformers.AutoModelForSeq2SeqLM` or :obj:`transformers.AutoModelForSeq2SeqLM`):
the model on which attribution is performed.
tokenizer (:obj:`transformers.AutoTokenizer`): the tokenizer associated to the model.
device (:obj:`str`): the device on which the model is run.
encoder_int_embeds (:obj:`captum.InterpretableEmbeddingBase`): the interpretable embedding layer for the
encoder, used for layer attribution methods in Captum.
decoder_int_embeds (:obj:`captum.InterpretableEmbeddingBase`): the interpretable embedding layer for the
decoder, used for layer attribution methods in Captum.
embed_scale (:obj:`float`, *optional*): scale factor for embeddings.
tokenizer_name (:obj:`str`, *optional*): The name of the tokenizer in the Huggingface Hub.
Default: use model name.
"""
_autoclass = None
def __init__(
self,
model: str | PreTrainedModel,
attribution_method: str | None = None,
tokenizer: str | PreTrainedTokenizerBase | None = None,
device: str | None = None,
model_kwargs: dict[str, Any] | None = {},
tokenizer_kwargs: dict[str, Any] | None = {},
**kwargs,
) -> None:
"""AttributionModel subclass for Huggingface-compatible models.
Args:
model (:obj:`str` or :obj:`transformers.PreTrainedModel`): the name of the model in the
Huggingface Hub or path to folder containing local model files.
attribution_method (str, optional): The attribution method to use.
Passing it here reduces overhead on attribute call, since it is already
initialized.
tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizerBase`, optional): the name of the tokenizer
in the Huggingface Hub or path to folder containing local tokenizer files.
Default: use model name.
device (str, optional): the Torch device on which the model is run.
**kwargs: additional arguments for the model and the tokenizer.
"""
super().__init__(**kwargs)
if "attn_implementation" not in model_kwargs:
model_kwargs["attn_implementation"] = "eager"
if self._autoclass is None or self._autoclass not in SUPPORTED_AUTOCLASSES:
raise ValueError(
f"Invalid autoclass {self._autoclass}. Must be one of {[x.__name__ for x in SUPPORTED_AUTOCLASSES]}."
)
if isinstance(model, PreTrainedModel):
self.model = model
else:
self.model = self._autoclass.from_pretrained(model, **model_kwargs)
# In transformers v5+, checkpoints containing both shared.weight and lm_head.weight
# are loaded without tying them, even when tie_word_embeddings=True. This causes
# scale_decoder_outputs=True to double-scale the logits (the untied lm_head already
# encodes the scaling). Detect this case and disable the redundant scaling.
if (
getattr(self.model.config, "tie_word_embeddings", False)
and getattr(self.model.config, "scale_decoder_outputs", False)
and hasattr(self.model, "lm_head")
and hasattr(self.model, "get_input_embeddings")
):
embed_layer = self.model.get_input_embeddings()
if embed_layer is not None and hasattr(embed_layer, "weight") and hasattr(self.model.lm_head, "weight"):
if embed_layer.weight.data_ptr() != self.model.lm_head.weight.data_ptr():
logger.warning(
"tie_word_embeddings=True but weights are not tied. "
"Disabling scale_decoder_outputs to avoid double-scaling logits."
)
self.model.config.scale_decoder_outputs = False
self.model_name = self.model.config.name_or_path
self.tokenizer_name = tokenizer if isinstance(tokenizer, str) else None
if tokenizer is None:
tokenizer = model if isinstance(model, str) else self.model_name
if not tokenizer:
raise ValueError(
"Unspecified tokenizer for model loaded from scratch. Use explicit identifier as tokenizer=<ID>"
"during model loading."
)
if isinstance(tokenizer, PreTrainedTokenizerBase):
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
if isinstance(self.eos_token_id, list):
self.eos_token_id = self.eos_token_id[0]
pad_token_id = self.model.config.pad_token_id
if pad_token_id is None:
if self.tokenizer.pad_token_id is None:
logger.info(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.")
pad_token_id = self.eos_token_id
else:
pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self._convert_ids_to_tokens(pad_token_id, skip_special_tokens=False)
if isinstance(self.pad_token, list):
self.pad_token = self.pad_token[0]
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.pad_token
if self.model.config.pad_token_id is None:
self.model.config.pad_token_id = pad_token_id
self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None)
if self.bos_token_id is None:
self.bos_token_id = self.model.config.bos_token_id
self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False)
if self.eos_token_id is None:
self.eos_token_id = self.tokenizer.pad_token_id
if self.tokenizer.unk_token_id is None:
self.tokenizer.unk_token_id = self.tokenizer.pad_token_id
self.embed_scale = 1.0
self.encoder_int_embeds = None
self.decoder_int_embeds = None
self.device_map = None
if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None:
self.device_map = self.model.hf_device_map
self.is_encoder_decoder = self.model.config.is_encoder_decoder
self.configure_embeddings_scale()
self.setup(device, attribution_method, **kwargs)
[docs]
@staticmethod
def load(
model: str | PreTrainedModel,
attribution_method: str | None = None,
tokenizer: str | PreTrainedTokenizerBase | None = None,
device: str = None,
model_kwargs: dict[str, Any] | None = {},
tokenizer_kwargs: dict[str, Any] | None = {},
**kwargs,
) -> "HuggingfaceModel":
"""Loads a HuggingFace model and tokenizer and wraps them in the appropriate AttributionModel."""
if isinstance(model, str):
is_encoder_decoder = AutoConfig.from_pretrained(model, **model_kwargs).is_encoder_decoder
else:
is_encoder_decoder = model.config.is_encoder_decoder
if is_encoder_decoder:
return HuggingfaceEncoderDecoderModel(
model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs
)
else:
return HuggingfaceDecoderOnlyModel(
model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs
)
@AttributionModel.device.setter
def device(self, new_device: str) -> None:
check_device(new_device)
self._device = new_device
is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
is_quantized = is_loaded_in_8bit or is_loaded_in_4bit
has_device_map = self.device_map is not None
# Enable compatibility with 8bit models
if self.model:
if is_quantized:
mode = "8bit" if is_loaded_in_8bit else "4bit"
logger.warning(
f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
)
elif has_device_map:
logger.warning("The model is loaded with a device map. The device cannot be changed after loading.")
else:
self.model.to(self._device)
@property
def info(self) -> dict[str, str]:
dic_info: dict[str, str] = super().info
extra_info = {
"tokenizer_name": self.tokenizer_name,
"tokenizer_class": self.tokenizer.__class__.__name__,
}
dic_info.update(extra_info)
return dic_info
[docs]
@unhooked
@batched
def generate(
self,
inputs: TextInput | BatchEncoding,
return_generation_output: bool = False,
skip_special_tokens: bool | None = None,
output_generated_only: bool = False,
**kwargs,
) -> list[str] | tuple[list[str], ModelOutput]:
"""Wrapper of model.generate to handle tokenization and decoding.
Args:
inputs (`Union[TextInput, BatchEncoding]`):
Inputs to be provided to the model for generation.
return_generation_output (`bool`, *optional*, defaults to False):
If true, generation outputs are returned alongside the generated text.
output_generated_only (`bool`, *optional*, defaults to False):
If true, only the generated text is returned. Relevant for decoder-only models that would otherwise return
the full input + output.
Returns:
`Union[List[str], Tuple[List[str], ModelOutput]]`: Generated text or a tuple of generated text and
generation outputs.
"""
if isinstance(inputs, str) or (
isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
):
inputs = self.encode(inputs, add_special_tokens=not skip_special_tokens)
inputs: BatchEncoding = inputs.to(self.device)
generation_out = self.model.generate(
inputs=inputs.input_ids,
return_dict_in_generate=True,
**kwargs,
)
sequences = generation_out.sequences
if output_generated_only and not self.is_encoder_decoder:
sequences = sequences[:, inputs.input_ids.shape[1] :]
# Left-padding in multi-sentence sequences is skipped by default.
if skip_special_tokens is None:
skip_special_tokens = inputs.num_sequences != 1 or self.is_encoder_decoder
texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
if return_generation_output:
return texts, generation_out
return texts
@staticmethod
def output2logits(forward_output: Seq2SeqLMOutput | CausalLMOutput) -> LogitsTensor:
# Full logits for last position of every sentence:
# (batch_size, tgt_seq_len, vocab_size) => (batch_size, vocab_size)
return forward_output.logits[:, -1, :].squeeze(1)
[docs]
def encode(
self,
texts: TextInput,
as_targets: bool = False,
return_baseline: bool = False,
include_eos_baseline: bool = False,
add_bos_token: bool = True,
add_special_tokens: bool = True,
) -> BatchEncoding:
"""Encode one or multiple texts, producing a BatchEncoding.
Args:
texts (str or list of str): the texts to tokenize.
return_baseline (bool, optional): if True, baseline token ids are returned.
Returns:
BatchEncoding: contains ids and attention masks.
"""
if as_targets and not self.is_encoder_decoder:
raise ValueError("Decoder-only models should use tokenization as source only.")
batch = self.tokenizer(
text=texts if not as_targets else None,
text_target=texts if as_targets else None,
add_special_tokens=add_special_tokens,
padding=True,
truncation=True,
return_tensors="pt",
).to(self.device)
baseline_ids = None
# Fix: If two BOS tokens are present (e.g. when using chat templates), the second one is removed.
if (
batch["input_ids"].shape[0] == 1
and len(batch["input_ids"][0]) >= 2
and batch["input_ids"][0][0] == batch["input_ids"][0][1] == self.bos_token_id
):
batch["input_ids"] = batch["input_ids"][:, 1:]
batch["attention_mask"] = batch["attention_mask"][:, 1:]
if return_baseline:
if include_eos_baseline:
baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.tokenizer.unk_token_id
else:
baseline_ids_non_eos = batch["input_ids"].ne(self.eos_token_id).long() * self.tokenizer.unk_token_id
baseline_ids_eos = batch["input_ids"].eq(self.eos_token_id).long() * self.eos_token_id
baseline_ids = baseline_ids_non_eos + baseline_ids_eos
# We prepend a BOS token only when tokenizing target texts.
if as_targets and self.is_encoder_decoder and add_bos_token:
ones_mask = torch.ones((batch["input_ids"].shape[0], 1), device=self.device, dtype=long)
batch["attention_mask"] = torch.cat((ones_mask, batch["attention_mask"]), dim=1)
bos_ids = ones_mask * self.bos_token_id
batch["input_ids"] = torch.cat((bos_ids, batch["input_ids"]), dim=1)
if return_baseline:
baseline_ids = torch.cat((bos_ids, baseline_ids), dim=1)
return BatchEncoding(
input_ids=batch["input_ids"],
input_tokens=[self._convert_ids_to_tokens(x, skip_special_tokens=False) for x in batch["input_ids"]],
attention_mask=batch["attention_mask"],
baseline_ids=baseline_ids,
)
def decode(
self,
ids: list[int] | list[list[int]] | IdsTensor,
skip_special_tokens: bool = True,
) -> list[str]:
return self.tokenizer.batch_decode(
ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=False,
)
def embed_ids(self, ids: IdsTensor, as_targets: bool = False) -> EmbeddingsTensor:
if as_targets and not self.is_encoder_decoder:
raise ValueError("Decoder-only models should use tokenization as source only.")
if self.encoder_int_embeds is not None and not as_targets:
embeddings = self.encoder_int_embeds.indices_to_embeddings(ids)
elif self.decoder_int_embeds is not None and as_targets:
embeddings = self.decoder_int_embeds.indices_to_embeddings(ids)
else:
embeddings = self.get_embedding_layer()(ids)
return embeddings * self.embed_scale
def _convert_ids_to_tokens(
self,
ids: IdsTensor | int,
skip_special_tokens: bool = True,
decode_tokens: bool = False,
) -> OneOrMoreTokenSequences | str:
"""Convert token IDs to token strings.
Args:
ids: Token IDs to convert. Can be a single int, list, or tensor.
skip_special_tokens: Whether to skip special tokens.
decode_tokens: If True, uses tokenizer.decode() for each token to get human-readable
strings. This is especially important for byte-level tokenizers (e.g., Qwen) where
convert_ids_to_tokens returns raw vocabulary entries that may be unreadable.
If False, uses convert_ids_to_tokens which returns raw vocabulary tokens.
Returns:
Token string (single ID) or list of token strings.
"""
# Handle single token ID
if isinstance(ids, int):
if decode_tokens:
return self.tokenizer.decode([ids])
token = self.tokenizer.convert_ids_to_tokens(ids)
return token.decode("utf-8") if isinstance(token, bytes) else token
if decode_tokens:
# Use decode for each token to get human-readable strings
# This handles byte-level tokenizers (like Qwen) correctly
ids_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)
special_ids = set(self.tokenizer.all_special_ids) if skip_special_tokens else set()
return [self.tokenizer.decode([tid]) for tid in ids_list if tid not in special_ids]
# Use convert_ids_to_tokens for raw vocabulary tokens
tokens = self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
if isinstance(tokens, bytes) and not isinstance(tokens, str):
return tokens.decode("utf-8")
if isinstance(tokens, list):
return [t.decode("utf-8") if isinstance(t, bytes) else t for t in tokens]
return tokens
[docs]
def convert_ids_to_tokens(
self,
ids: IdsTensor,
skip_special_tokens: bool | None = True,
decode_tokens: bool = False,
) -> OneOrMoreTokenSequences:
"""Convert token IDs to token strings.
Args:
ids: Token IDs to convert. Can be 1D or 2D tensor.
skip_special_tokens: Whether to skip special tokens.
decode_tokens: If True, uses tokenizer.decode() for each token to get human-readable
strings. This is especially important for byte-level tokenizers (e.g., Qwen) where
convert_ids_to_tokens returns raw vocabulary entries that may be unreadable.
Returns:
List of token strings (1D input) or list of lists of token strings (2D input).
"""
if ids.ndim < 2:
return self._convert_ids_to_tokens(ids, skip_special_tokens, decode_tokens)
return [self._convert_ids_to_tokens(id_slice, skip_special_tokens, decode_tokens) for id_slice in ids]
def convert_tokens_to_ids(self, tokens: TextInput) -> OneOrMoreIdSequences:
if isinstance(tokens[0], str):
return self.tokenizer.convert_tokens_to_ids(tokens)
return [self.tokenizer.convert_tokens_to_ids(token_slice) for token_slice in tokens]
def convert_tokens_to_string(
self,
tokens: OneOrMoreTokenSequences,
skip_special_tokens: bool = True,
as_targets: bool = False,
) -> TextInput:
if isinstance(tokens, list) and len(tokens) == 0:
return ""
elif isinstance(tokens[0], bytes | str):
filtered_tokens = (
tokens if not skip_special_tokens else [t for t in tokens if t not in self.special_tokens]
)
# _decode_use_source_tokenizer was removed in transformers v5.0.0
# For older versions, we temporarily set it to control source/target tokenization
if hasattr(self.tokenizer, "_decode_use_source_tokenizer"):
tmp_decode_state = self.tokenizer._decode_use_source_tokenizer
self.tokenizer._decode_use_source_tokenizer = not as_targets
out_strings = self.tokenizer.convert_tokens_to_string(filtered_tokens)
self.tokenizer._decode_use_source_tokenizer = tmp_decode_state
else:
out_strings = self.tokenizer.convert_tokens_to_string(filtered_tokens)
return out_strings
return [self.convert_tokens_to_string(token_slice, skip_special_tokens, as_targets) for token_slice in tokens]
def convert_string_to_tokens(
self,
text: TextInput,
skip_special_tokens: bool = True,
as_targets: bool = False,
) -> OneOrMoreTokenSequences:
if isinstance(text, str):
ids = self.tokenizer(
text=text if not as_targets else None,
text_target=text if as_targets else None,
add_special_tokens=not skip_special_tokens,
)["input_ids"]
return self._convert_ids_to_tokens(ids, skip_special_tokens)
return [self.convert_string_to_tokens(t, skip_special_tokens, as_targets) for t in text]
[docs]
def clean_tokens(
self,
tokens: OneOrMoreTokenSequences | OneOrMoreTokenWithIdSequences,
skip_special_tokens: bool = False,
as_targets: bool = False,
) -> OneOrMoreTokenSequences:
"""Cleans special characters from tokens.
Args:
tokens (`OneOrMoreTokenSequences`):
A list containing one or more lists of tokens.
skip_special_tokens (`bool`, *optional*, defaults to True):
If true, special tokens are skipped.
as_targets (`bool`, *optional*, defaults to False):
If true, a target tokenizer is used to clean the tokens.
Returns:
`OneOrMoreTokenSequences`: A list containing one or more lists of cleaned tokens.
"""
if isinstance(tokens, list) and len(tokens) == 0:
return []
elif isinstance(tokens[0], bytes | str | TokenWithId):
clean_tokens = []
for tok in tokens:
str_tok = tok.token if isinstance(tok, TokenWithId) else tok
clean_str_tok = self.convert_tokens_to_string(
[str_tok], skip_special_tokens=skip_special_tokens, as_targets=as_targets
)
if not clean_str_tok and tok:
clean_str_tok = tok
clean_tok = TokenWithId(clean_str_tok, tok.id) if isinstance(tok, TokenWithId) else clean_str_tok
clean_tokens.append(clean_tok)
return clean_tokens
return [self.clean_tokens(token_seq, skip_special_tokens, as_targets) for token_seq in tokens]
@property
def special_tokens(self) -> list[str]:
return self.tokenizer.all_special_tokens
@property
def special_tokens_ids(self) -> list[int]:
return self.tokenizer.all_special_ids
@property
def vocabulary_embeddings(self) -> VocabularyEmbeddingsTensor:
return self.get_embedding_layer().weight
def get_embedding_layer(self) -> torch.nn.Module:
return self.model.get_input_embeddings()
[docs]
class HuggingfaceEncoderDecoderModel(HuggingfaceModel, EncoderDecoderAttributionModel):
"""Model wrapper for any ForConditionalGeneration model on the HuggingFace Hub used to enable
feature attribution. Corresponds to AutoModelForSeq2SeqLM auto classes in HF transformers.
Attributes:
model (::obj:`transformers.AutoModelForSeq2SeqLM`):
the model on which attribution is performed.
"""
_autoclass = AutoModelForSeq2SeqLM
def get_encoder(self) -> torch.nn.Module:
return self.model.get_encoder()
def get_decoder(self) -> torch.nn.Module:
return self.model.get_decoder()
@staticmethod
def get_attentions_dict(
output: Seq2SeqLMOutput,
) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.encoder_attentions is None or output.decoder_attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
if output.encoder_attentions is not None:
output.encoder_attentions = tuple(att.to("cpu") for att in output.encoder_attentions)
if output.decoder_attentions is not None:
output.decoder_attentions = tuple(att.to("cpu") for att in output.decoder_attentions)
if output.cross_attentions is not None:
output.cross_attentions = tuple(att.to("cpu") for att in output.cross_attentions)
return {
"encoder_self_attentions": torch.stack(output.encoder_attentions, dim=1),
"decoder_self_attentions": torch.stack(output.decoder_attentions, dim=1),
"cross_attentions": torch.stack(output.cross_attentions, dim=1),
}
@staticmethod
def get_hidden_states_dict(output: Seq2SeqLMOutput) -> dict[str, MultiLayerEmbeddingsTensor]:
return {
"encoder_hidden_states": torch.stack(output.encoder_hidden_states, dim=1),
"decoder_hidden_states": torch.stack(output.decoder_hidden_states, dim=1),
}
[docs]
class HuggingfaceDecoderOnlyModel(HuggingfaceModel, DecoderOnlyAttributionModel):
"""Model wrapper for any ForCausalLM or LMHead model on the HuggingFace Hub used to enable
feature attribution. Corresponds to AutoModelForCausalLM auto classes in HF transformers.
Attributes:
model (::obj:`transformers.AutoModelForCausalLM`):
the model on which attribution is performed.
"""
_autoclass = AutoModelForCausalLM
def __init__(
self,
model: str | PreTrainedModel,
attribution_method: str | None = None,
tokenizer: str | PreTrainedTokenizerBase | None = None,
device: str = None,
model_kwargs: dict[str, Any] | None = {},
tokenizer_kwargs: dict[str, Any] | None = {},
**kwargs,
) -> NoReturn:
super().__init__(model, attribution_method, tokenizer, device, model_kwargs, tokenizer_kwargs, **kwargs)
self.tokenizer.padding_side = "left"
self.tokenizer.truncation_side = "left"
if self.pad_token is None:
self.pad_token = self.tokenizer.bos_token
self.tokenizer.pad_token = self.tokenizer.bos_token
@staticmethod
def get_attentions_dict(output: CausalLMOutput) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
else:
output.attentions = tuple(att.to("cpu") for att in output.attentions)
return {
"decoder_self_attentions": torch.stack(output.attentions, dim=1),
}
@staticmethod
def get_hidden_states_dict(output: CausalLMOutput) -> dict[str, MultiLayerEmbeddingsTensor]:
return {
"decoder_hidden_states": torch.stack(output.hidden_states, dim=1),
}