# Copyright 2023 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import abstractmethod

import torch
from torch.linalg import vector_norm

from ..utils import Registry, available_classes
from ..utils.typing import (
    ScoreTensor,
)

logger = logging.getLogger(__name__)


class AggregationFunction(Registry):
    """Base class for callable aggregation functions.

    Subclasses implement ``__call__`` to reduce attribution scores along a
    dimension and return a ``ScoreTensor``. The class participates in the
    project ``Registry`` via ``registry_attr``; presumably each concrete
    subclass is looked up by the value of its ``aggregation_function_name``
    attribute (NOTE(review): confirm against ``Registry`` in ``..utils``).
    """

    # Name of the class attribute the Registry uses to identify subclasses.
    registry_attr = "aggregation_function_name"

    def __init__(self):
        # Whether ``__call__`` expects a single tensor rather than a tuple of
        # tensors. Presumably consulted by callers before invoking the
        # function (TODO confirm usage at call sites).
        self.takes_single_tensor: bool = True
        # Whether the function also consumes sequence-level scores; defaults
        # to False here, subclasses may override (TODO confirm usage).
        self.takes_sequence_scores: bool = False

    @abstractmethod
    def __call__(
        self,
        scores: torch.Tensor | tuple[torch.Tensor, ...],
        dim: int,
        **kwargs,
    ) -> ScoreTensor:
        """Aggregate ``scores`` along dimension ``dim``.

        Args:
            scores: A single tensor, or a tuple of tensors, of scores to
                aggregate.
            dim: The dimension along which aggregation is performed.
            **kwargs: Extra, subclass-specific options.

        Returns:
            ScoreTensor: The aggregated scores.

        NOTE(review): ``@abstractmethod`` only prevents instantiation of
        subclasses that do not override this method if ``Registry`` uses
        ``abc.ABCMeta`` as its metaclass; confirm in ``..utils``.
        """
        pass
def list_aggregation_functions() -> list[str]:
    """Lists identifiers for all available aggregation functions.

    Returns:
        list[str]: The identifiers that ``available_classes`` reports for the
        ``AggregationFunction`` hierarchy. These are presumably the
        ``aggregation_function_name`` values of the registered subclasses
        (NOTE(review): confirm against ``available_classes`` in ``..utils``).
    """
    return available_classes(AggregationFunction)