from copy import deepcopy
from dataclasses import dataclass, fields
from typing import Any, TypeVar
import numpy as np
import torch
import treescope as ts
from jaxtyping import Int
from ..utils import isnotebook, pretty_dict
# Type variable bound to TensorWrapper so methods called on a subclass are typed as returning that subclass.
TensorClass = TypeVar("TensorClass", bound="TensorWrapper")


@dataclass
class TensorWrapper:
    """Wrapper for tensors and lists of tensors to allow for easy access to their attributes.

    Subclasses declare their payload as dataclass fields. The helpers below walk those fields
    and dispatch on each value's type (``torch.Tensor``, nested ``TensorWrapper``, ``list``,
    ``dict``); values of any other type are passed through untouched.

    * ``__getitem__``, ``slice_batch``, ``select_active`` and ``clone`` build and return a NEW
      instance.
    * ``to``, ``detach``, ``numpy`` and ``torch`` replace the field values IN PLACE and return
      ``self`` so calls can be chained.

    This class defines a method named ``torch``. Inside the class body, any eagerly evaluated
    name lookup placed after that ``def`` would see the method rather than the module, so
    annotations that mention ``torch.*`` are written as strings. Method bodies are unaffected
    (they resolve ``torch`` in module globals).

    NOTE(review): ``@dataclass`` only skips generating ``__eq__``/``__repr__`` when the class
    being decorated defines them itself. A subclass decorated with a plain ``@dataclass``
    therefore gets generated versions that shadow the tensor-aware ones below — confirm that
    subclasses use ``@dataclass(eq=False, repr=False)`` or re-declare these methods.
    """

    @staticmethod
    def _getitem(attr, subscript):
        """Index ``attr`` along the sequence axis (helper for ``__getitem__``).

        Tensors: 0-d are returned unchanged, 1-d are indexed directly, ndim >= 2 are indexed on
        dim 1 (dim 0 is kept whole). Nested wrappers recurse through their own ``__getitem__``;
        a list of lists has every inner list indexed; dicts recurse per value. Anything else,
        including flat or empty lists, is returned unchanged.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 0:
                # A scalar has no axis to index; without this branch it would fall through to None.
                return attr
            if attr.ndim == 1:
                return attr[subscript]
            return attr[:, subscript, ...]
        if isinstance(attr, TensorWrapper):
            return attr[subscript]
        # `attr and ...` keeps the attr[0] probe from raising IndexError on an empty list.
        if isinstance(attr, list) and attr and isinstance(attr[0], list):
            return [seq[subscript] for seq in attr]
        if isinstance(attr, dict):
            return {key: TensorWrapper._getitem(val, subscript) for key, val in attr.items()}
        return attr

    @staticmethod
    def _slice_batch(attr, subscript):
        """Index ``attr`` along its leading (batch) axis (helper for ``slice_batch``).

        0-d tensors are returned unchanged; other tensors and lists are indexed on their first
        axis; nested wrappers recurse through ``slice_batch`` (NOT ``__getitem__``, which would
        slice their sequence axis instead); dicts recurse per value; anything else is returned
        unchanged.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 0:
                return attr
            if attr.ndim == 1:
                return attr[subscript]
            return attr[subscript, ...]
        if isinstance(attr, TensorWrapper):
            return attr.slice_batch(subscript)
        if isinstance(attr, list):
            return attr[subscript]
        if isinstance(attr, dict):
            return {key: TensorWrapper._slice_batch(val, subscript) for key, val in attr.items()}
        return attr

    @staticmethod
    def _select_active(attr, mask):
        """Keep only the batch rows of ``attr`` whose ``mask`` entry is non-zero.

        ``mask`` is annotated as ``(batch_size, 1)``; a flat ``(batch_size,)`` mask also works.
        Tensors with ndim >= 2 are filtered on dim 0; nested wrappers recurse through
        ``select_active``; lists are filtered element-wise; dicts recurse per value. Anything
        else is returned unchanged.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim <= 1:
                # NOTE(review): 0-d and 1-d tensors are returned unfiltered, even if batch-sized — confirm intended.
                return attr
            keep = mask.bool()
            # Append trailing singleton dims so the per-row flag broadcasts across each row.
            while keep.ndim < attr.ndim:
                keep = keep.unsqueeze(-1)
            return attr.masked_select(keep).reshape(-1, *attr.shape[1:])
        if isinstance(attr, TensorWrapper):
            return attr.select_active(mask)
        if isinstance(attr, list):
            keep = mask.bool()
            if keep.ndim > 1:
                # (batch, 1) -> (batch,): otherwise tolist() yields one-element lists such as [0],
                # which are always truthy, and no element would ever be dropped.
                keep = keep.flatten(1).any(dim=1)
            flags = keep.tolist()  # convert once, not once per element
            return [val for i, val in enumerate(attr) if flags[i]]
        if isinstance(attr, dict):
            return {key: TensorWrapper._select_active(val, mask) for key, val in attr.items()}
        return attr

    @staticmethod
    def _to(attr, device: "str | torch.device"):
        """Move tensors and nested wrappers to ``device``, recursing into dicts.

        Lists (even lists of tensors) and other values are returned unchanged.
        """
        if isinstance(attr, torch.Tensor | TensorWrapper):
            return attr.to(device)
        if isinstance(attr, dict):
            return {key: TensorWrapper._to(val, device) for key, val in attr.items()}
        return attr

    @staticmethod
    def _detach(attr):
        """Detach tensors and nested wrappers, recursing into dicts; other values pass through."""
        if isinstance(attr, torch.Tensor | TensorWrapper):
            return attr.detach()
        if isinstance(attr, dict):
            return {key: TensorWrapper._detach(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _numpy(attr):
        """Convert tensors (and nested wrappers) to NumPy, recursing into dicts.

        ``Tensor.numpy()`` raises for tensors that require grad or live off the CPU, so callers
        are expected to ``detach()`` and ``to("cpu")`` first.
        """
        if isinstance(attr, torch.Tensor | TensorWrapper):
            converted = attr.numpy()
            if isinstance(converted, np.ndarray):
                # Tensor.numpy() shares the tensor's storage and strides, so the array may not be
                # C-contiguous; force a contiguous layout.
                return np.ascontiguousarray(converted)
            # A nested wrapper converts itself in place and returns the wrapper.
            return converted
        if isinstance(attr, dict):
            return {key: TensorWrapper._numpy(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _torch(attr):
        """Convert NumPy arrays back to tensors (copying), recursing into wrappers and dicts."""
        if isinstance(attr, np.ndarray):
            return torch.tensor(attr)
        if isinstance(attr, TensorWrapper):
            return attr.torch()
        if isinstance(attr, dict):
            return {key: TensorWrapper._torch(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _eq(self_attr: Any, other_attr: Any) -> bool:
        """Tolerant, never-raising equality used by ``__eq__``.

        * Tensors and numeric arrays: same shape and ``allclose`` with ``atol=1e-5`` (NaN equals
          NaN). Shapes are checked first because ``allclose`` would otherwise broadcast.
        * Non-numeric arrays: exact ``array_equal``.
        * Dicts: identical key sets and every value pair equal.
        * Lists: same length and every element pair identical or equal.
        * Anything else: ``bool(self_attr == other_attr)``.

        Any exception raised while comparing (dtype mismatch, incomparable types, ...) yields
        ``False``.
        """
        try:
            if isinstance(self_attr, torch.Tensor):
                return (
                    isinstance(other_attr, torch.Tensor)
                    and self_attr.shape == other_attr.shape
                    and torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
                )
            if isinstance(self_attr, np.ndarray):
                if not isinstance(other_attr, np.ndarray) or self_attr.shape != other_attr.shape:
                    return False
                if np.issubdtype(self_attr.dtype, np.number) and np.issubdtype(other_attr.dtype, np.number):
                    return bool(np.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5))
                # bool/str/object arrays cannot go through allclose (it subtracts).
                return bool(np.array_equal(self_attr, other_attr))
            if isinstance(self_attr, dict):
                return self_attr.keys() == other_attr.keys() and all(
                    TensorWrapper._eq(val, other_attr[key]) for key, val in self_attr.items()
                )
            if isinstance(self_attr, list):
                return (
                    isinstance(other_attr, list)
                    and len(self_attr) == len(other_attr)
                    # `a is b` mirrors the identity shortcut of built-in list equality.
                    and all(a is b or TensorWrapper._eq(a, b) for a, b in zip(self_attr, other_attr))
                )
            # bool() here so an element-wise result surfaces as an exception (-> False) inside the
            # try, instead of crashing later in __eq__'s all().
            return bool(self_attr == other_attr)
        except Exception:
            return False

    def __getitem__(self: TensorClass, subscript) -> TensorClass:
        """By default, idiomatic slicing is used for the sequence dimension across batches.
        For batching use `slice_batch` instead.

        Returns a new instance of the same class built from the sliced field values.
        """
        cls = self.__class__
        return cls(**{f.name: self._getitem(getattr(self, f.name), subscript) for f in fields(cls)})

    def slice_batch(self: TensorClass, subscript) -> TensorClass:
        """Return a new instance with every field indexed along its leading (batch) axis."""
        cls = self.__class__
        return cls(**{f.name: self._slice_batch(getattr(self, f.name), subscript) for f in fields(cls)})

    def select_active(self: TensorClass, mask: "Int[torch.Tensor, 'batch_size 1']") -> TensorClass:
        """Return a new instance keeping only the batch rows whose ``mask`` entry is non-zero."""
        cls = self.__class__
        return cls(**{f.name: self._select_active(getattr(self, f.name), mask) for f in fields(cls)})

    def to(self: TensorClass, device: "str | torch.device") -> TensorClass:
        """Move every field to ``device`` in place and return ``self``.

        When the target is the CPU and CUDA is available, ``torch.cuda.empty_cache()`` is called
        so the CUDA caching allocator releases its unused cached blocks.
        """
        for f in fields(self.__class__):
            setattr(self, f.name, self._to(getattr(self, f.name), device))
        # str() so a torch.device("cpu") target is recognised as well as the plain "cpu" string.
        if str(device) == "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        return self

    def detach(self: TensorClass) -> TensorClass:
        """Detach every field from the autograd graph in place and return ``self``."""
        for f in fields(self.__class__):
            setattr(self, f.name, self._detach(getattr(self, f.name)))
        return self

    def numpy(self: TensorClass) -> TensorClass:
        """Convert every dataclass field to NumPy in place and return ``self``."""
        for f in fields(self.__class__):
            setattr(self, f.name, self._numpy(getattr(self, f.name)))
        return self

    def torch(self: TensorClass) -> TensorClass:
        """Convert NumPy values back to tensors in place and return ``self``.

        NOTE(review): unlike ``numpy()``, this walks ``to_dict()`` (public ``__dict__`` entries)
        rather than the dataclass fields, so ``_``-prefixed fields converted by ``numpy()`` are not
        converted back — confirm intended.
        """
        for name, value in self.to_dict().items():
            setattr(self, name, self._torch(value))
        return self

    def clone(self: TensorClass) -> TensorClass:
        """Return an independent copy of the wrapper.

        Tensors and nested wrappers are copied with ``.clone()``; every other field value goes
        through ``deepcopy``.
        """
        copied = {}
        for f in fields(self.__class__):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor | TensorWrapper):
                copied[f.name] = value.clone()
            else:
                # deepcopy(None) is None, so None needs no special case.
                copied[f.name] = deepcopy(value)
        return self.__class__(**copied)

    def clone_empty(self: TensorClass) -> TensorClass:
        """Return a new instance carrying over only the ``_``-prefixed, non-None attributes.

        Every other field must therefore have a default, or the constructor raises TypeError.
        """
        private = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
        return self.__class__(**private)

    def to_dict(self: TensorClass) -> dict[str, Any]:
        """Return the public (non-``_``-prefixed) instance attributes as a new dict."""
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def __str__(self):
        return f"{self.__class__.__name__}({pretty_dict(self.to_dict())})"

    def __repr__(self):
        # In a notebook the wrapper is rendered through treescope as a side effect and an empty
        # string is returned (presumably to avoid a duplicate textual repr).
        if isnotebook():
            ts.display(self)
            return ""
        return self.__str__()

    def __eq__(self, other):
        """Wrappers are equal when they hold the same attribute names and every value pair
        matches under ``_eq``. Non-wrappers return ``NotImplemented``, so Python falls back to
        its default handling.
        """
        if not isinstance(other, TensorWrapper):
            return NotImplemented
        mine, theirs = self.__dict__, other.__dict__
        if mine.keys() != theirs.keys():
            return False
        return all(self._eq(value, theirs[name]) for name, value in mine.items())

    def __json_encode__(self):
        """Serialize a detached, CPU, NumPy copy of the public attributes (``self`` is untouched)."""
        return self.clone().detach().to("cpu").numpy().to_dict()

    def __json_decode__(self, **attrs):
        """Restore state from ``attrs`` without running ``__init__``, then call ``__post_init__``."""
        # Assumes instances have a __dict__ (i.e. the class does not use __slots__).
        self.__dict__ = attrs
        self.__post_init__()

    def __post_init__(self):
        """Hook for subclasses; also re-run by ``__json_decode__``."""
        pass