Extending torch.Tensor

Hi all,
I’m trying to extend torch.Tensor as follows:

```python
import torch


class MyTensor1(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])
```

and

```python
class MyTensor2(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val * torch.ones_like(self[1, 2])
```

With these classes, I have two problems:

1. I want property access and any kind of slicing to return a plain torch.Tensor. Currently:

```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1.mid
y = tensor1[:, -1]
```

both x and y are of type MyTensor1. Is it possible to get torch.Tensor instead, without explicitly calling torch.tensor() here or inside the property?

2. Some operations between the two subclasses fail, as if the inheritance were not working. For example:

```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1 @ tensor2
```

raises

```
TypeError: unsupported operand type(s) for @: 'MyTensor1' and 'MyTensor2'
```

even though multiplying two 3x3 matrices is perfectly well defined.
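As far as I can tell, both behaviors come from the default `torch.Tensor.__torch_function__`: it returns `NotImplemented` unless `cls` is a subclass of every type in `types`, and otherwise converts tensor outputs back to `cls`. A minimal sketch, assuming PyTorch 2.x and two hypothetical, empty sibling subclasses `A` and `B`, reproduces both issues without any custom methods:

```python
import torch


# Two hypothetical, unrelated subclasses with no custom behavior.
class A(torch.Tensor):
    pass


class B(torch.Tensor):
    pass


a = torch.ones(3, 3).as_subclass(A)
b = torch.ones(3, 3).as_subclass(B)

# Issue 1: the default __torch_function__ re-wraps tensor outputs as the
# subclass, for indexing and for every other op.
print(type(a[1, 1]).__name__)   # A
print(type(a[:, -1]).__name__)  # A
print(type(a.sum()).__name__)   # A

# Issue 2: the default implementation declines (returns NotImplemented)
# when cls is not a subclass of every type involved. A and B are siblings,
# so both sides decline and Python raises TypeError for `a @ b`.
print(A.__torch_function__(torch.matmul, (A, B), (a, b)))  # NotImplemented
try:
    a @ b
except TypeError as err:
    print(err)
```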

Thanks for your help!

Edit:
Following the docs (https://pytorch.org/docs/stable/notes/extending.html), I tried adding a __torch_function__ override:

```python
import torch


class MyTensor1(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])
```

However, this didn’t solve issue #2.
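Forwarding to `super().__torch_function__` keeps the default behavior, including the `issubclass` check on `types` that rejects the `MyTensor1`/`MyTensor2` pair. A minimal sketch of an override that addresses both issues, assuming PyTorch 2.x: run `func` with subclass dispatch disabled and return its result without converting it back to `cls`. `torch._C.DisableTorchFunctionSubclass` is a private API (older releases used `torch._C.DisableTorchFunction` for this), `as_subclass` replaces the second construction in `__new__`, and the sample data is only illustrative.

```python
import torch


class MyTensor1(torch.Tensor):
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError(f"expected shape (3, 3), got {tuple(tensor.shape)}")
        # Re-type the validated tensor instead of constructing it a second time.
        return tensor.as_subclass(cls)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Unlike super().__torch_function__, skip the issubclass check on
        # `types` (the source of the TypeError) and do not convert the result
        # back to cls, so outputs stay plain torch.Tensor.
        # DisableTorchFunctionSubclass is private API (PyTorch 2.x).
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        # Setting a single element from a Python scalar.
        self[1, 1] = val


class MyTensor2(torch.Tensor):
    # Sibling subclass with the default __torch_function__
    # (shape check omitted for brevity).
    def __new__(cls, data, **kwargs):
        return torch.tensor(data, **kwargs).as_subclass(cls)


data = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
tensor1 = MyTensor1(data)
tensor2 = MyTensor2(data)

x = tensor1 @ tensor2  # no TypeError
print(type(x), type(tensor1.mid), type(tensor1[:, -1]))
```

Because the override ignores `types`, `MyTensor1` will also accept operations mixed with unrelated tensor subclasses; if that is undesirable, the check could be narrowed to the classes you know about. `tensor2 @ tensor1` works too: `MyTensor2` declines, and PyTorch then tries `MyTensor1`'s override.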

As for issue #1: I opened a feature request, but I’m wondering whether there is a workaround in the meantime.
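One possible workaround, sketched under the assumption that every operation should return a plain tensor: move the shared logic into a common base class (the name `Base3x3` is hypothetical) and opt out of `__torch_function__` entirely with `torch._C._disabled_torch_function_impl`. This is the same private hook `torch.nn.Parameter` uses, which is why operations on parameters return plain tensors. Indexing, properties, and `@` between the two subclasses then all produce `torch.Tensor`.

```python
import torch


class Base3x3(torch.Tensor):
    """Hypothetical common base class for the 3x3 subclasses."""

    # Opt out of __torch_function__ the way torch.nn.Parameter does
    # (private API). Ops then run as on a plain tensor and return plain
    # torch.Tensor results, whatever mix of subclasses is involved.
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError(f"expected shape (3, 3), got {tuple(tensor.shape)}")
        # Re-type the validated tensor rather than building it twice.
        return tensor.as_subclass(cls)


class MyTensor1(Base3x3):
    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val


class MyTensor2(Base3x3):
    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val


data = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
tensor1 = MyTensor1(data)
tensor2 = MyTensor2(data)
x = tensor1 @ tensor2  # plain torch.Tensor, no TypeError
```

The trade-off is that no operation ever returns your subclass, so if you want a subclass back for some result you have to re-wrap it explicitly (for example with `as_subclass`).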