Where can I find tensor.view() source code?


I am trying to subclass torch.Tensor to define a new MyTensor class that has additional attributes beyond those defined in torch.Tensor. Please refer to the reference topic: LINK.

But the problem is that some methods, such as torch.Tensor.permute(), torch.Tensor.view(), torch.Tensor.transpose() and others, return a new tensor instance that does not carry over the values of the additional attributes. For instance,

a = MyTensor([1.], extra_attr="a")
b = a.view(-1)

Then, although type(b) is MyTensor, b does not inherit the extra_attr attribute. So I am trying to override these methods, but I cannot find their source code, so I don't know their exact arguments.
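To make the problem concrete, here is a minimal reproduction. The MyTensor definition is an assumption (the class from the linked topic isn't shown here); it simply stores extra_attr as a plain Python attribute set in __new__.

```python
import torch


class MyTensor(torch.Tensor):
    # Assumed minimal subclass: extra_attr lives in the instance __dict__.
    @staticmethod
    def __new__(cls, data, extra_attr=None):
        # Build a plain tensor, then re-type it as MyTensor (shares storage).
        t = torch.as_tensor(data).as_subclass(cls)
        t.extra_attr = extra_attr
        return t


a = MyTensor([1.], extra_attr="a")
b = a.view(-1)

# The default Tensor.__torch_function__ re-wraps the result as MyTensor,
# but as a brand-new Python object, so the attribute is not copied.
print(type(b).__name__)
print(hasattr(b, "extra_attr"))
```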

To answer your question directly, the source code can be found here: pytorch/TensorShape.cpp at master · pytorch/pytorch · GitHub

You might be better off just looking at the function signatures, though, which can all be found in pytorch/native_functions.yaml at 1d08b5b1034c273fb64123c095aa2ab41eae3ea0 · pytorch/pytorch · GitHub
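If you'd rather not read the YAML, the same operator signatures can also be printed from a running Python session. This is a sketch assuming PyTorch 2.x: it relies on torch.ops.aten.&lt;op&gt;.overloads() and on the underscore-prefixed OpOverload._schema attribute, which is internal and may change between releases. list_aten_schemas is a hypothetical helper name.

```python
import torch


def list_aten_schemas(op_name):
    """Return the schema string of every overload of aten::<op_name>.

    Uses OpOverloadPacket.overloads() and the internal OpOverload._schema
    attribute, so treat this as a debugging aid rather than a stable API.
    """
    packet = getattr(torch.ops.aten, op_name)
    return [str(getattr(packet, name)._schema) for name in packet.overloads()]


for op in ("view", "permute", "transpose"):
    for schema in list_aten_schemas(op):
        print(schema)
```

The printed strings use the same schema syntax as the func: entries in native_functions.yaml, so they show the argument names and types you would need when overriding these methods.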

Also, are you looking at Extending PyTorch — PyTorch 2.0 documentation for subclassing Tensor?
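The approach described on the Extending PyTorch page is to override __torch_function__ rather than each method individually. Below is a sketch using an assumed MyTensor definition (not the class from the linked topic). The default Tensor.__torch_function__ re-wraps outputs as the subclass but creates fresh Python objects, so the override copies extra_attr from the first MyTensor input onto every MyTensor output. The helpers _first_extra_attr and _set_extra_attr are hypothetical names, and "first input wins" is an assumed policy; ops that combine tensors with different extra_attr values may need a different rule.

```python
import torch


class MyTensor(torch.Tensor):
    """Tensor subclass carrying one extra Python attribute, extra_attr."""

    @staticmethod
    def __new__(cls, data, extra_attr=None):
        # Build a plain tensor, then re-type it as MyTensor (shares storage).
        t = torch.as_tensor(data).as_subclass(cls)
        t.extra_attr = extra_attr
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Default behaviour: run the op and re-wrap tensor outputs as cls.
        # Those outputs are new Python objects without extra_attr.
        result = super().__torch_function__(func, types, args, kwargs)
        # Copy the attribute from the first MyTensor input, searching
        # positional args (including nested lists such as torch.cat([...]))
        # and then keyword args.
        extra = _first_extra_attr(list(args) + list(kwargs.values()))
        _set_extra_attr(result, extra)
        return result


def _first_extra_attr(objs):
    # Hypothetical helper: depth-first search for the first MyTensor.
    for o in objs:
        if isinstance(o, MyTensor):
            return getattr(o, "extra_attr", None)
        if isinstance(o, (list, tuple)):
            found = _first_extra_attr(o)
            if found is not None:
                return found
    return None


def _set_extra_attr(result, extra):
    # Hypothetical helper: outputs may be one tensor or a tuple/list of
    # tensors (e.g. split), so walk containers as well.
    if isinstance(result, MyTensor):
        result.extra_attr = extra
    elif isinstance(result, (list, tuple)):
        for r in result:
            _set_extra_attr(r, extra)


a = MyTensor([[1., 2.], [3., 4.]], extra_attr="a")
b = a.view(-1)
print(type(b).__name__, b.extra_attr)
```

With this in place, view(), permute(), transpose() and most other tensor methods and torch.* functions should no longer need individual overrides, because calls involving a MyTensor are routed through __torch_function__.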

Awesome! That’s what I’ve been looking for. Thanks.

One more question about the .view() implementation: how does the C++ .view() function get bound as a Python method of the torch.Tensor class? I guess binding just the Tensor class from the ATen C++ implementation is enough?

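Yeah, the torch.Tensor class inherits from a tensor base class written in C++ (torch._C._TensorBase in PyTorch 2.0) whose Python methods, including view(), are code-generated from native_functions.yaml.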
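A quick way to check this from Python, assuming PyTorch 2.x. As far as I know, the base class's name differs between releases (torch._C._TensorBase in 2.0, torch._C.TensorBase in later versions), so the snippet looks it up through __bases__ instead of hard-coding it.

```python
import inspect

import torch

# torch.Tensor is a Python class (torch/_tensor.py) with a single C++ base.
base = torch.Tensor.__bases__[0]
print(base)

for name in ("view", "permute", "transpose"):
    method = vars(base)[name]
    # Methods implemented in C appear as method descriptors,
    # not as ordinary Python functions.
    print(name, type(method).__name__, inspect.ismethoddescriptor(method))
```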