Hi all,
I’m trying to extend `torch.Tensor` as follows:
```python
import torch


class MyTensor1(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError
        return super().__new__(cls, data, **kwargs)

    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        self[1, 1] = val * torch.ones_like(self[1, 1])
```
and
```python
class MyTensor2(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError
        return super().__new__(cls, data, **kwargs)

    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val * torch.ones_like(self[1, 2])
```
Given that, I have two questions/problems:
- I want accessing a property, or any sort of slicing, to return a plain `torch.Tensor` object. Currently:
```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1.mid
y = tensor1[:, -1]
```
both `x` and `y` come back as `MyTensor1`. Is it possible for them to be plain `torch.Tensor` objects instead (without explicitly calling `torch.tensor()` here or inside the property)?
- When I try to apply some operations, the inheritance doesn’t seem to work. For example:
```python
tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
x = tensor1 @ tensor2
```
results in
```
TypeError: unsupported operand type(s) for @: 'MyTensor1' and 'MyTensor2'
```
even though the operation is well defined (both operands are 3x3 tensors).
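One possible approach to both points, sketched under the assumption of PyTorch 1.7 or newer: the default `Tensor.__torch_function__` re-wraps the results of torch operations (including indexing) in the subclass, which is why `x` and `y` come back as `MyTensor1`. It also appears to return `NotImplemented` when the operands are unrelated subclasses, which would surface as the `@` TypeError. `torch.nn.Parameter` opts out of this by setting `__torch_function__ = torch._C._disabled_torch_function_impl` (a private API, so it may change between releases), and the same trick can be borrowed here. The shared base class `Matrix3x3` is a hypothetical name introduced to avoid duplicating `__new__`, and `tensor.as_subclass(cls)` is used instead of the legacy `super().__new__(cls, data, **kwargs)` constructor so that the tensor built by `torch.tensor(data, **kwargs)` (with its dtype) is kept.

```python
import torch


class Matrix3x3(torch.Tensor):
    """Hypothetical shared base: validates the 3x3 shape once for both subclasses."""

    # Opt out of __torch_function__ (same trick torch.nn.Parameter uses; private API).
    # Indexing and other torch operations then return plain torch.Tensor objects,
    # and mixing unrelated subclasses no longer returns NotImplemented.
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __new__(cls, data, **kwargs):
        tensor = torch.tensor(data, **kwargs)
        if tensor.shape != (3, 3):
            raise ValueError(f"expected shape (3, 3), got {tuple(tensor.shape)}")
        # Reinterpret the already-built tensor as an instance of the subclass
        return tensor.as_subclass(cls)


class MyTensor1(Matrix3x3):
    @property
    def mid(self):
        return self[1, 1]

    @mid.setter
    def mid(self, val: float):
        # A plain scalar assignment is enough for a single element
        self[1, 1] = val


class MyTensor2(Matrix3x3):
    @property
    def mid_low(self):
        return self[1, 2]

    @mid_low.setter
    def mid_low(self, val: float):
        self[1, 2] = val


tensor1 = MyTensor1([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor2 = MyTensor2([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

x = tensor1.mid      # property access
y = tensor1[:, -1]   # slicing
z = tensor1 @ tensor2  # matmul across two unrelated subclasses

print(type(x).__name__, type(y).__name__, type(z).__name__)
```

The trade-off with this design is that every torch operation on these classes, including arithmetic on a single `MyTensor1` such as `tensor1 * 2`, now returns a plain `torch.Tensor`, so the result no longer has the `mid` property. Only objects created directly through the constructor keep the subclass type.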
Thanks for your help!