The latest torchvision v2 transforms allow creating custom TVTensors (see here), which lets you register functional transforms specific to a TVTensor type.
But what about applying the default functional (e.g. F.pad) to my custom TVTensor inside its registered kernel? Calling F.pad on the TVTensor itself would just dispatch back into my own kernel, so I need to convert the TVTensor to a plain torch.Tensor first. The way I'm doing it below creates a copy, though, and copy-constructing with torch.tensor() triggers the warning that recommends clone().detach() instead. Here is my current setup:
```python
from typing import Any, Optional, Union

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


class DepthmapVideoTVTensor(tv_tensors.TVTensor):
    """
    Modified from the Video TVTensor here:
    https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_video.py
    """

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ):
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim != 4:  # expected shape: (seq_len, C, H, W)
            raise ValueError(f"Expected a 4D (seq_len, C, H, W) input, got {tensor.ndim} dimensions")
        return tensor.as_subclass(cls)

    def __repr__(self, *, tensor_contents: Any = None) -> str:
        return self._make_repr()
```
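For completeness, this is roughly how I construct it (the shape below is just an example, not my real data):

```python
# Minimal usage sketch: 8 frames of single-channel 240x320 depth maps (made-up shape)
depth_video = DepthmapVideoTVTensor(torch.rand(8, 1, 240, 320))
print(depth_video.shape)                       # torch.Size([8, 1, 240, 320])
print(isinstance(depth_video, torch.Tensor))   # True -- it is still a Tensor subclass
```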
And here is the kernel I register for F.pad:

```python
@F.register_kernel(functional=F.pad, tv_tensor_cls=DepthmapVideoTVTensor)
def pad_depthmap(depthmap, *args, **kwargs):
    print("Padding Depthmap!")
    # This copies the data and triggers the torch.tensor() copy-construct warning
    depthmap_new = torch.tensor(depthmap)
    depthmap_new = F.pad(depthmap_new, *args, **kwargs)  # dispatches to the plain-Tensor kernel
    return tv_tensors.wrap(depthmap_new, like=depthmap)
```
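One alternative I'm considering is to view the TVTensor as a plain Tensor with as_subclass, which shares the storage instead of copying (and, if I'm reading the docs right, low-level kernels like F.pad_image are also exposed and could be called directly). A minimal sketch of that version, as a drop-in replacement for the kernel above:

```python
@F.register_kernel(functional=F.pad, tv_tensor_cls=DepthmapVideoTVTensor)
def pad_depthmap(depthmap, *args, **kwargs):
    print("Padding Depthmap!")
    # View as a plain torch.Tensor (no copy), so F.pad dispatches to the default
    # tensor kernel instead of recursing back into this function
    plain = depthmap.as_subclass(torch.Tensor)
    padded = F.pad(plain, *args, **kwargs)
    return tv_tensors.wrap(padded, like=depthmap)
```

But I'm not sure this is the intended way to unwrap a TVTensor.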
Any better ideas? Thanks!