I implemented a subclass of PyTorch's `torch.Tensor` in my application, and I need to operate on these custom tensors inside a callback hooked to a tensor's backward pass. However, I found that inside such hooked callbacks, any operation on the custom tensor yields a plain `torch.Tensor`, whereas I want it to return a tensor of the custom class.
A minimal example:
```python
import torch

class CTensor(torch.Tensor):
    pass

def hook(grad):
    print('within hook:', type(a[:1]))

a = CTensor([1, 2, 3]).requires_grad_()
b = CTensor([1, 2, 3]).requires_grad_()
b.register_hook(hook)
c = b.sum()
print('outside hook:', type(a[:1]))
c.backward()
```
Output:

```
outside hook: <class '__main__.CTensor'>
within hook: <class 'torch.Tensor'>
```
However, I need the slicing operation inside `hook()` to also return a tensor of my custom class `CTensor`. What should I do? Or is this not possible?
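A likely explanation, hedged because the details vary across PyTorch versions: since `c` is itself a `CTensor`, `c.backward()` is routed through the default `torch.Tensor.__torch_function__`, which runs the wrapped call with subclass dispatch disabled. In this example the backward pass, hook included, executes inside that call, so `a[:1]` in the hook falls back to a plain `torch.Tensor`. A minimal workaround sketch, assuming it is acceptable to re-wrap results explicitly, is to call the public `Tensor.as_subclass(CTensor)` inside the hook. The `seen_types` list is a hypothetical addition used only to record what the hook observed.

```python
import torch


class CTensor(torch.Tensor):
    pass


# Hypothetical helper list: records the type the hook ends up working with.
seen_types = []


def hook(grad):
    # Inside backward, a[:1] may come back as a plain torch.Tensor.
    # as_subclass() returns a CTensor that shares the same data.
    sliced = a[:1].as_subclass(CTensor)
    seen_types.append(type(sliced))
    print('within hook:', type(sliced))
    # Returning None leaves the gradient unchanged.


a = CTensor([1, 2, 3]).requires_grad_()
b = CTensor([1, 2, 3]).requires_grad_()
b.register_hook(hook)
c = b.sum()
print('outside hook:', type(a[:1]))
c.backward()
```

Because `as_subclass` always returns an instance of the requested class, the same hook should keep working on versions where the slice already comes back as a `CTensor`.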