# Extending torch.Tensor

Hi all,
I’m trying to extend `torch.Tensor` as follows:

```python
import torch


class MyTensor1(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])
```

and

```python
class MyTensor2(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val * torch.ones_like(self[1, 2])
```

With these classes, I have two questions/problems:

1. I want property access and any sort of slicing to return a plain `torch.Tensor` object. Currently:
```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1.mid
y = tensor1[:, -1]
```

both `x` and `y` are of type `MyTensor1`. Is it possible to have them be plain `torch.Tensor` objects (without explicitly calling `torch.tensor()` here or in the property)?

2. When I try to apply some operations, inheritance doesn't seem to work as expected. For example:
```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1 @ tensor2
```

results in

```
TypeError: unsupported operand type(s) for @: 'MyTensor1' and 'MyTensor2'
```

even though multiplying two 3x3 tensors is perfectly well defined.
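Digging into this a bit: as far as I can tell, the error comes from the default `torch.Tensor.__torch_function__`, which returns `NotImplemented` unless the class it is called on is a subclass of every type in `types`. `MyTensor1` is not a subclass of `MyTensor2` (and vice versa), so both `__matmul__` and the reflected `__rmatmul__` give up and Python raises the `TypeError`. A minimal reproduction, using two hypothetical empty subclasses `A` and `B`:

```python
import torch


class A(torch.Tensor):
    pass


class B(torch.Tensor):
    pass


a = torch.eye(3).as_subclass(A)
b = torch.eye(3).as_subclass(B)

# The inherited default handler refuses: A is not a subclass of B.
print(A.__torch_function__(torch.matmul, (A, B), (a, b)) is NotImplemented)

# Same-type operands pass the check; mixed sibling types do not.
print(type(a @ a).__name__)
try:
    a @ b
except TypeError as err:
    print("TypeError:", err)
```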

Edit:
Following the docs at https://pytorch.org/docs/stable/notes/extending.html, I added a `__torch_function__` override:

```python
class MyTensor1(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)

        if tensor.shape != (3, 3):
            raise ValueError

        return super().__new__(cls, data, **kwargs)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])
```

but it didn’t solve issue #2.
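My guess is that the override above still passes the original `types` to `super().__torch_function__`, so it hits the same subclass check. A possible fix, sketched under the assumption that both classes can share a common parent (`Base3x3` is a hypothetical name, not part of PyTorch), is to report every `Base3x3` subclass as plain `torch.Tensor` before delegating. The check then passes, and the result is converted to the class of whichever handler runs first, which for `tensor1 @ tensor2` is the left operand's class. One caveat: results are converted with `as_subclass` rather than through `__new__`, so the 3x3 check is not re-applied to them (e.g. `tensor1.mid` is a 0-d `MyTensor1`).

```python
import torch


class Base3x3(torch.Tensor):
    """Hypothetical shared parent for 3x3 subclasses that should interoperate."""

    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError(f"expected shape (3, 3), got {tuple(tensor.shape)}")
        # Reuse the validated tensor instead of converting `data` a second time.
        return tensor.as_subclass(cls)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Report sibling subclasses as torch.Tensor so the default handler's
        # "cls must be a subclass of every type" check passes for MyTensor1 @ MyTensor2.
        types = tuple(torch.Tensor if issubclass(t, Base3x3) else t for t in types)
        return super().__torch_function__(func, types, args, kwargs)


class MyTensor1(Base3x3):
    @property
    def mid(self):
        return self[1, 1]


class MyTensor2(Base3x3):
    @property
    def mid_low(self):
        return self[1, 2]


tensor1 = MyTensor1([[1.0, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1.0, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1 @ tensor2  # no TypeError; converted to the left operand's class
```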

As for issue #1, I opened a feature request, but I’m wondering whether there is a workaround in the meantime.
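One workaround I’m considering, sketched here under the assumption that it is fine for *every* torch operation on these classes to return a plain tensor: override `__torch_function__` and forward to the base `torch.Tensor.__torch_function__` with `types=(torch.Tensor,)`. The default implementation converts results to the class it is called on, so calling it on `torch.Tensor` itself should leave outputs as plain tensors, and since it no longer returns `NotImplemented` for mixed subclasses, `tensor1 @ tensor2` works too. `Base3x3` is a hypothetical shared parent; this relies on the documented default behaviour rather than a dedicated API, so it is worth re-checking on your PyTorch version. In-place ops that return `self` will still hand back the subclass object.

```python
import torch


class Base3x3(torch.Tensor):
    """Hypothetical shared parent: validates the shape; torch ops on it return plain tensors."""

    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError(f"expected shape (3, 3), got {tuple(tensor.shape)}")
        return tensor.as_subclass(cls)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Delegate as if only torch.Tensor were involved: the subclass check passes,
        # and results are converted to torch.Tensor (i.e. left as plain tensors).
        return torch.Tensor.__torch_function__(func, (torch.Tensor,), args, kwargs)


class MyTensor1(Base3x3):
    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])


class MyTensor2(Base3x3):
    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val * torch.ones_like(self[1, 2])


tensor1 = MyTensor1([[1.0, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1.0, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1.mid        # plain torch.Tensor
y = tensor1[:, -1]     # plain torch.Tensor
z = tensor1 @ tensor2  # plain torch.Tensor, no TypeError
```

The property setter still writes into the subclass instance, because indexing assignment modifies the underlying storage in place.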