import logging
import warnings
from dataclasses import dataclass
from pathlib import Path

import yaml

logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration used by the methods for which the attribute ``use_model_config=True``.

    Args:
        self_attention_module (:obj:`str`):
            The name of the module performing the self-attention computation (e.g. ``attn`` for the GPT-2 model in
            transformers). Can be identified by looking at the name of the self-attention module attribute
            in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
        cross_attention_module (:obj:`str`):
            The name of the module performing the cross-attention computation (e.g. ``encoder_attn`` for MarianMT
            models in transformers). Can be identified by looking at the name of the cross-attention module attribute
            in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
            Defaults to ``None``.
        value_vector (:obj:`str`):
            The name of the variable in the forward pass of the attention module containing the value vector
            (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
            the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
    """

    self_attention_module: str
    value_vector: str
    cross_attention_module: str | None = None


# Default configurations for models not in the config file
DEFAULT_DECODER_ONLY_CONFIG = ModelConfig(
    self_attention_module="attn",
    value_vector="value",
)

DEFAULT_ENCODER_DECODER_CONFIG = ModelConfig(
    self_attention_module="self_attn",
    cross_attention_module="cross_attention",
    value_vector="value",
)


def _load_model_configs() -> dict[str, ModelConfig]:
    """Read ``model_config.yaml`` next to this module and build one :class:`ModelConfig` per entry."""
    config_path = Path(__file__).parent / "model_config.yaml"
    # The context manager guarantees the file handle is closed once parsing is done.
    with config_path.open(encoding="utf8") as config_file:
        # ``safe_load`` returns ``None`` for an empty document; treat that as "no configurations".
        raw_configs = yaml.safe_load(config_file) or {}
    return {model_type: ModelConfig(**cfg) for model_type, cfg in raw_configs.items()}


MODEL_CONFIGS: dict[str, ModelConfig] = _load_model_configs()


def get_model_config(model_type: str, is_encoder_decoder: bool = False) -> ModelConfig:
    """Get the model configuration for the given model type.

    If ``model_type`` is not a key of ``MODEL_CONFIGS``, a :class:`UserWarning` is emitted and one of the default
    configurations is returned instead.

    Args:
        model_type (`str`):
            The class name of the model (e.g. ``GPT2LMHeadModel``).
        is_encoder_decoder (`bool`, *optional*, defaults to False):
            Whether the model is an encoder-decoder model. Used to determine the default configuration when the
            model type is not found in the config.

    Returns:
        :class:`~inseq.models.ModelConfig`: The model configuration.
    """
    if model_type not in MODEL_CONFIGS:
        default_config = DEFAULT_ENCODER_DECODER_CONFIG if is_encoder_decoder else DEFAULT_DECODER_ONLY_CONFIG
        warnings.warn(
            f"Model configuration for '{model_type}' not found. Using default "
            f"{'encoder-decoder' if is_encoder_decoder else 'decoder-only'} configuration "
            f"(self_attention_module='{default_config.self_attention_module}', "
            f"value_vector='{default_config.value_vector}'"
            # The cross-attention module is only reported for encoder-decoder defaults.
            + (f", cross_attention_module='{default_config.cross_attention_module}'" if is_encoder_decoder else "")
            + "). If this doesn't work for your model, you can register a custom configuration with "
            ":meth:`~inseq.register_model_config`, or request it to be added to the library by opening an issue "
            "on GitHub: https://github.com/inseq-team/inseq/issues",
            UserWarning,
            # Attribute the warning to the caller of get_model_config rather than to this function.
            stacklevel=2,
        )
        return default_config
    return MODEL_CONFIGS[model_type]
def register_model_config(
    model_type: str,
    config: dict,
    overwrite: bool = False,
    allow_partial: bool = False,
) -> None:
    """Allows to register a model configuration for a given model type.

    The configuration is a dictionary containing information required by the methods for which the attribute
    ``use_model_config=True``.

    Args:
        model_type (`str`):
            The class of the model for which the configuration is registered, used as key in the stored
            configuration. E.g. GPT2LMHeadModel for the GPT-2 model in HuggingFace Transformers.
        config (`dict`):
            A dictionary containing the configuration for the model. The fields should match those of the
            :class:`~inseq.models.ModelConfig` class. The dictionary passed by the caller is not modified.
        overwrite (`bool`, *optional*, defaults to False):
            If `True`, the configuration will be overwritten if it already exists.
        allow_partial (`bool`, *optional*, defaults to False):
            If `True`, the configuration can be partial, i.e. it can contain only a subset of the fields of the
            :class:`~inseq.models.ModelConfig` class. The missing fields will be set to `None`.

    Raises:
        `ValueError`: If the model type is already registered and `overwrite=False`, or if the configuration is
            partial and `allow_partial=False`.
        `TypeError`: If ``config`` contains keys that are not fields of :class:`~inseq.models.ModelConfig`
            (raised by the dataclass constructor).
    """
    if model_type in MODEL_CONFIGS:
        if not overwrite:
            raise ValueError(
                f"{model_type} is already registered in model configurations. Override with overwrite=True."
            )
        logger.warning("Overwriting %s config.", model_type)

    # Every dataclass field counts as required here, including those with a default value.
    missing_fields = set(ModelConfig.__dataclass_fields__) - set(config)
    if missing_fields and not allow_partial:
        # Sort so the error message is deterministic regardless of set iteration order.
        raise ValueError(
            f"Missing fields {','.join(sorted(missing_fields))} in model configuration for {model_type}. "
            "Set allow_partial=True to allow partial configuration."
        )
    if allow_partial:
        # Fill absent fields with None; user-provided values take precedence.
        config = {**dict.fromkeys(missing_fields), **config}
    MODEL_CONFIGS[model_type] = ModelConfig(**config)