import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any
from ... import list_step_functions
from ...attr.step_functions import is_contrastive_step_function
from ...utils import cli_arg, pretty_dict
from ..attribute import AttributeBaseArgs
from ..commands_utils import command_args_docstring
logger = logging.getLogger(__name__)
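# Strategies for handling the separation of model-generated output context from the
# output current text (see ``handle_output_context_strategy`` below).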
class HandleOutputContextSetting(Enum):
MANUAL = "manual"
AUTO = "auto"
PRE = "pre"
@command_args_docstring
@dataclass
class AttributeContextInputArgs:
input_current_text: str = cli_arg(
default="",
help=(
"The input text used for generation. If the model is a decoder-only model, the input text is a "
"prompt used for language modeling. If the model is an encoder-decoder model, the input text is the "
"source text provided as input to the encoder. It will be formatted as {current} in the "
"``input_template``."
),
)
input_context_text: str | None = cli_arg(
default=None,
help=(
"Additional input context influencing the generation of ``output_current_text``. If the model is a"
" decoder-only model, the input text is a prefix to the ``input_current_text`` prompt. If the model is an"
" encoder-decoder model, the input context is part of the source text provided as input to the encoder. "
" It will be formatted as {context} in the ``input_template``."
),
)
input_template: str | None = cli_arg(
default=None,
help=(
"The template used to format model inputs. The template must contain at least the"
" ``{current}`` placeholder, which will be replaced by ``input_current_text``. If ``{context}`` is"
" also specified, input-side context will be used. Can be modified for models requiring special tokens or"
" formatting in the input text (e.g. <brk> tags to separate context and current inputs)."
" Defaults to '{context} {current}' if ``input_context_text`` is provided, '{current}' otherwise."
),
)
output_context_text: str | None = cli_arg(
default=None,
help=(
"An output contexts for which context sensitivity should be detected. For encoder-decoder models, this"
" is a target-side prefix to the output_current_text used as input to the decoder. For decoder-only "
" models this is a portion of the model generation that should be considered as an additional context "
" (e.g. a chain-of-thought sequence). It will be formatted as {context} in the ``output_template``."
" If not provided but specified in the ``output_template``, the output context will be generated"
" along with the output current text, and user validation might be required to separate the two."
),
)
output_current_text: str | None = cli_arg(
default=None,
help=(
"The output text generated by the model when all available contexts are provided. Tokens in "
" ``output_current_text`` will be tested for context-sensitivity, and their generation will be attributed "
" to input/target contexts (if present) in case they are found to be context-sensitive. If specified, "
" this output is force-decoded. Otherwise, it is generated by the model using infilled ``input_template`` "
" and ``output_template``. It will be formatted as {current} in the ``output_template``."
),
)
output_template: str | None = cli_arg(
default=None,
help=(
"The template used to format model outputs. The template must contain at least the"
" ``{current}`` placeholder, which will be replaced by ``output_current_text``. If ``{context}`` is"
" also specified, output-side context will be used. Can be modified for models requiring special tokens or"
" formatting in the output text (e.g. <brk> tags to separate context and current outputs)."
" Defaults to '{context} {current}' if ``output_context_text`` is provided, '{current}' otherwise."
),
)
contextless_input_current_text: str | None = cli_arg(
default=None,
help=(
"The input current text or template to use in the contrastive comparison with contextual input. By default"
" it is the same as ``input_current_text``, but it can be useful in cases where the context is nested "
"inside the current text (e.g. for an ``input_template`` like <user>\n{context}\n{current}\n<assistant> we "
"can use this parameter to format the contextless version as <user>\n{current}\n<assistant>)."
"If it contains the tag {current}, it will be infilled with the ``input_current_text``. Otherwise, it will"
" be used as-is for the contrastive comparison, enabling contrastive comparison with different inputs."
),
)
contextless_output_current_text: str | None = cli_arg(
default=None,
help=(
"The output current text or template to use in the contrastive comparison with contextual output. By default"
" it is the same as ``output_current_text``, but it can be useful in cases where the context is nested "
"inside the current text (e.g. for an ``output_template`` like <user>\n{context}\n{current}\n<assistant> we "
"can use this parameter to format the contextless version as <user>\n{current}\n<assistant>)."
"If it contains the tag {current}, it will be infilled with the ``output_current_text``. Otherwise, it will"
" be used as-is for the contrastive comparison, enabling contrastive comparison with different outputs."
),
)
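# Illustrative sketch of template infilling (made-up values; the actual infilling is
# performed by the attribute-context pipeline using the fields above):
#
#   input_template = "{context} {current}"
#   input_context_text = "Berlin is the capital of Germany."
#   input_current_text = "What is the capital of Germany?"
#   # infilled input -> "Berlin is the capital of Germany. What is the capital of Germany?"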
@command_args_docstring
@dataclass
class AttributeContextMethodArgs(AttributeBaseArgs):
context_sensitivity_metric: str = cli_arg(
default="kl_divergence",
help="The contrastive metric used to detect context-sensitive tokens in ``output_current_text``.",
choices=[fn for fn in list_step_functions() if is_contrastive_step_function(fn)],
)
handle_output_context_strategy: str = cli_arg(
default=HandleOutputContextSetting.MANUAL.value,
choices=[e.value for e in HandleOutputContextSetting],
help=(
"Specifies how output context should be handled when it is produced together with the output current text,"
" and the two need to be separated for context sensitivity detection.\n"
"Options:\n"
"- `manual`: The user is prompted to verify an automatic context detection attempt, and optionally to"
" provide a correct context separation manually.\n"
"- `auto`: Attempts an automatic detection of context using an automatic alignment with source context"
" (assuming an MT-like task).\n"
"- `pre`: If context is required but not pre-defined by the user via the ``output_context_text`` argument,"
" the execution will fail instead of attempting to prompt the user for the output context."
),
)
contextless_output_next_tokens: list[str] = cli_arg(
default_factory=list,
help=(
"If specified, it should provide a list of one token per CCI output indicating the next token that should"
" be force-decoded as contextless output instead of the natural output produced by"
" ``get_contextless_output``. This is ignored if the ``attributed_fn`` used is not contrastive."
),
)
prompt_user_for_contextless_output_next_tokens: bool = cli_arg(
default=False,
help=(
"If specified, the user is prompted to provide the next token that should be force-decoded as contextless"
" output instead of the natural output produced by ``get_contextless_output``. This is ignored if the"
" ``attributed_fn`` used is not contrastive."
),
)
special_tokens_to_keep: list[str] = cli_arg(
default_factory=list,
help="Special tokens to preserve in the generated string, e.g. ``<brk>`` separator between context and current.",
)
decoder_input_output_separator: str = cli_arg(
default=" ",
help=(
"If specified, the separator used to split the input and output of the decoder. If not specified, the"
" separator is a whitespace character."
),
)
context_sensitivity_std_threshold: float = cli_arg(
default=1.0,
help=(
"Parameter to control the selection of ``output_current_text`` tokens considered as context-sensitive for "
"moving onwards with attribution. Corresponds to the number of standard deviations above or below the mean"
" ``context_sensitivity_metric`` score for tokens to be considered context-sensitive."
),
)
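    # Note: the thresholding rule is assumed from the help text above: a token counts as
    # context-sensitive when its metric score deviates from the mean by at least
    # ``context_sensitivity_std_threshold`` standard deviations (above or below).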
context_sensitivity_topk: int | None = cli_arg(
default=None,
help=(
"If set, after selecting the salient context-sensitive tokens with ``context_sensitivity_std_threshold`` "
"only the top-K remaining tokens are used. By default no top-k selection is performed."
),
)
attribution_std_threshold: float = cli_arg(
default=1.0,
help=(
"Parameter to control the selection of ``input_context_text`` and ``output_context_text`` tokens "
"considered as salient as a result for the attribution process. Corresponds to the number of standard "
"deviations above or below the mean ``attribution_method`` score for tokens to be considered salient. "
"CCI scores for all context tokens are saved in the output, but this parameter controls which tokens are "
"used in the visualization of context reliance."
),
)
attribution_topk: int | None = cli_arg(
default=None,
help=(
"If set, after selecting the most salient tokens with ``attribution_std_threshold`` "
"only the top-K remaining tokens are used. By default no top-k selection is performed."
),
)
@command_args_docstring
@dataclass
class AttributeContextOutputArgs:
show_intermediate_outputs: bool = cli_arg(
default=False,
help=(
"If specified, the intermediate outputs produced by the Inseq library for context-sensitive target "
"identification (CTI) and contextual cues imputation (CCI) are shown during the process."
),
)
save_path: str | None = cli_arg(
default=None,
aliases=["-o"],
help="If present, the output of the two-step process will be saved in JSON format at the specified path.",
)
add_output_info: bool = cli_arg(
default=True,
help="If specified, additional information about the attribution process is added to the saved output.",
)
viz_path: str | None = cli_arg(
default=None,
help="If specified, the visualization produced from the output is saved in HTML format at the specified path.",
)
show_viz: bool = cli_arg(
default=True,
help="If specified, the visualization produced from the output is shown in the terminal.",
)
@command_args_docstring
@dataclass
class AttributeContextArgs(AttributeContextInputArgs, AttributeContextMethodArgs, AttributeContextOutputArgs):
def __repr__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"
@classmethod
def _to_dict(cls, val: Any) -> dict[str, Any]:
if val is None or isinstance(val, str | int | float | bool):
return val
elif isinstance(val, dict):
return {k: cls._to_dict(v) for k, v in val.items()}
elif isinstance(val, list | tuple):
return [cls._to_dict(v) for v in val]
else:
return str(val)
def to_dict(self) -> dict[str, Any]:
return self._to_dict(self.__dict__)
def __post_init__(self):
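        # Validate argument combinations, infill template defaults, and derive
        # context flags used by the attribution pipeline.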
        if (
            self.handle_output_context_strategy == HandleOutputContextSetting.PRE.value
            and not self.output_context_text
            and self.output_template is not None
            and "{context}" in self.output_template
        ):
raise ValueError(
"If --handle_output_context_strategy='pre' and {context} is used in --output_template, --output_context_text"
" must be specified to avoid user prompt for output context."
)
if len(self.contextless_output_next_tokens) > 0 and self.prompt_user_for_contextless_output_next_tokens:
raise ValueError(
"Only one of contextless_output_next_tokens and prompt_user_for_contextless_output_next_tokens can be"
" specified."
)
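        # Infill default templates: context is prepended to the current text
        # only when the corresponding context string was provided.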
if self.input_template is None:
self.input_template = "{current}" if self.input_context_text is None else "{context} {current}"
if self.output_template is None:
self.output_template = "{current}" if self.output_context_text is None else "{context} {current}"
if self.contextless_input_current_text is None:
self.contextless_input_current_text = "{current}"
if "{current}" not in self.contextless_input_current_text:
raise ValueError(
"{current} placeholder is missing from contextless_input_current_text template"
f" {self.contextless_input_current_text}."
)
if self.contextless_output_current_text is None:
self.contextless_output_current_text = "{current}"
if "{current}" not in self.contextless_output_current_text:
raise ValueError(
"{current} placeholder is missing from contextless_output_current_text template"
f" {self.contextless_output_current_text}."
)
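        # Derived flags: whether each template requests a {context} segment.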
self.has_input_context = "{context}" in self.input_template
self.has_output_context = "{context}" in self.output_template
if not self.input_current_text:
raise ValueError("--input_current_text must be a non-empty string.")
if self.input_context_text and not self.has_input_context:
logger.warning(
f"input_template has format {self.input_template} (no {{context}}), but --input_context_text is"
" specified. Ignoring provided --input_context_text."
)
self.input_context_text = None
if self.output_context_text and not self.has_output_context:
logger.warning(
f"output_template has format {self.output_template} (no {{context}}), but --output_context_text is"
" specified. Ignoring provided --output_context_text."
)
self.output_context_text = None
if not self.input_context_text and self.has_input_context:
raise ValueError(
f"{{context}} format placeholder is present in input_template {self.input_template},"
" but --input_context_text is not specified."
)
if "{current}" not in self.input_template:
raise ValueError(f"{{current}} format placeholder is missing from input_template {self.input_template}.")
if "{current}" not in self.output_template:
raise ValueError(f"{{current}} format placeholder is missing from output_template {self.output_template}.")
if self.has_output_context and self.output_template.find("{context}") > self.output_template.find("{current}"):
raise ValueError(
f"{{context}} placeholder must appear before {{current}} in output_template '{self.output_template}'."
)
if not self.output_template.endswith("{current}"):
*_, suffix = self.output_template.partition("{current}")
logger.warning(
f"Suffix '{suffix}' was specified in output_template and will be used to ignore the specified suffix"
" tokens during context sensitivity detection. Make sure that the suffix corresponds to the end of the"
" output_current_text by forcing --output_current_text if necessary."
)
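# Usage sketch (illustrative; values are made up, and the CLI wiring around this
# dataclass provided by the ``attribute-context`` command is assumed):
#
#   args = AttributeContextArgs(
#       input_current_text="What is the capital of Germany?",
#       input_context_text="Berlin is the capital of Germany.",
#   )
#   args.to_dict()  # JSON-serializable view, e.g. for the --save_path output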