Hi!
I implemented a subclass of the PyTorch tensor in my application, and I need to operate on such custom tensors in a callback function hooked to a tensor's backward pass. However, I found that inside such hooked callbacks, any operation on a custom tensor yields a plain torch.Tensor, while I want it to return a tensor of the custom class.
A minimal example:
import torch
class CTensor(torch.Tensor):
    pass
def hook(grad):
    print('within hook:', type(a[:1]))
a = CTensor([1,2,3]).requires_grad_()
b = CTensor([1,2,3]).requires_grad_()
b.register_hook(hook)
c = b.sum()
print('outside hook:', type(a[:1]))
c.backward()
This prints
outside hook: <class '__main__.CTensor'>
within hook: <class 'torch.Tensor'>
However, I need the slicing operation inside hook() to also return a tensor of my custom class CTensor. What should I do? Or is this not possible?
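One workaround worth trying is to re-wrap the result inside the hook with `Tensor.as_subclass`, which creates an instance of the given class that shares the same data. This is a sketch, not verified against every PyTorch version. A plausible cause, which I have not confirmed for all versions, is that `c.backward()` on a subclass is dispatched through the default `Tensor.__torch_function__`, which runs the call (and therefore the hook) with subclass handling disabled. The `seen_types` list and the `sliced` variable below are hypothetical helpers, used only to check the result:

```python
import torch


class CTensor(torch.Tensor):
    pass


# Hypothetical helper: records the type observed inside the hook.
seen_types = []


def hook(grad):
    # Inside the backward hook, a[:1] comes back as a plain torch.Tensor,
    # so explicitly re-wrap it as CTensor (shares storage with the slice).
    sliced = a[:1].as_subclass(CTensor)
    seen_types.append(type(sliced))
    print('within hook:', type(sliced))


a = CTensor([1, 2, 3]).requires_grad_()
b = CTensor([1, 2, 3]).requires_grad_()
b.register_hook(hook)
c = b.sum()
c.backward()
```

This only re-wraps the single result. Further operations on `sliced` inside the hook would presumably still return plain `torch.Tensor`, so every result that is needed as a `CTensor` would have to be wrapped the same way.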
Thanks!!