I am trying to subclass torch.Tensor to define a new MyTensor class that carries additional attributes beyond those defined in torch.Tensor. Please refer to the reference topic: LINK.
The problem is that functions such as torch.Tensor.permute(), torch.Tensor.view(), torch.Tensor.transpose(), and others return a new tensor instance that does not carry over the values of the additional attributes. For instance,
a = MyTensor([1.], extra_attr="a")
b = a.view(-1)
Here, although type(b) is MyTensor, b does not inherit the extra_attr attribute. So I am trying to override these functions, but I cannot find their source code, so I do not know their exact arguments.
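One possible way to avoid overriding each method by hand is the `__torch_function__` hook, which PyTorch (1.7 or later) calls for operations on Tensor subclasses. The sketch below is only an illustration, not the library's recommended recipe: the constructor signature of MyTensor and the rule "copy extra_attr from the first MyTensor argument" are assumptions made for this example.

```python
import torch


class MyTensor(torch.Tensor):
    """Tensor subclass that carries one extra Python attribute (illustrative)."""

    @staticmethod
    def __new__(cls, data, extra_attr=None):
        # Build a plain tensor, then reinterpret it as this subclass.
        t = torch.as_tensor(data).as_subclass(cls)
        t.extra_attr = extra_attr
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Let the default implementation run the op and re-wrap the result.
        out = super().__torch_function__(func, types, args, kwargs)
        if isinstance(out, MyTensor):
            # Assumption: take the attribute from the first MyTensor argument.
            src = next((a for a in args if isinstance(a, MyTensor)), None)
            if src is not None:
                out.extra_attr = getattr(src, "extra_attr", None)
        return out


a = MyTensor([[1.0, 2.0], [3.0, 4.0]], extra_attr="a")
b = a.view(-1)
c = a.permute(1, 0)
d = a.transpose(0, 1)
print(type(b).__name__, b.extra_attr, c.extra_attr, d.extra_attr)
```

Because the hook sees every call, this covers view(), permute(), transpose(), and similar methods without needing their exact argument lists. Operations that combine two MyTensor inputs with different attributes would need an explicit merge rule, which this sketch does not define.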
Awesome! That’s what I’ve been looking for. Thanks.
One more question about the .view() implementation: how is .view() bound as a Python method on the torch.Tensor class? My guess is that binding just the Tensor class from the ATen/torch C++ frontend implementation is enough; is that right?