My goal is to set an attribute on a tensor that was saved for backward. The problem is that the tensor returned by ctx.saved_tensors in backward is a different Python object from the one the caller holds, so setting an attribute on it has no effect outside the function. What is the correct way of passing information out of backward?
import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('>>> forward')
        print(f'id(x) {id(x)}')
        print(f'id(x.data) {id(x.data)}')
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.saved_tensors[0]
        print('>>> backward')
        print(f'id(x) {id(x)}')
        print(f'id(x.data) {id(x.data)}')
        x.message = 'hello'
        return torch.eye(x.shape[0])

x = torch.rand(2, 2, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.message)
Output:
>>> forward
id(x) 139894137130512
id(x.data) 139894137128640
>>> backward
id(x) 139894137130800
id(x.data) 139894137130872
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-b19465919e2e> in <module>
22 x = torch.rand(2, 2, requires_grad=True)
23 MyFunction.apply(x).sum().backward()
---> 24 print(x.message)
AttributeError: 'Tensor' object has no attribute 'message'
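One workaround I've been experimenting with (a rough sketch, assuming non-tensor arguments are passed through apply unchanged and that storing a plain dict on ctx is allowed) is to hand a mutable container into forward and fill it in backward, instead of tagging the saved tensor:

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, info):
        # `info` is a plain dict; non-tensor arguments are not wrapped by autograd,
        # so the object seen in backward is the same one the caller holds.
        ctx.info = info
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        ctx.info['message'] = 'hello'  # visible to the caller after backward()
        return grad_out, None          # no gradient for the dict argument

x = torch.rand(2, 2, requires_grad=True)
info = {}
MyFunction.apply(x, info).sum().backward()
print(info.get('message'))  # prints: hello

This avoids touching the saved tensor at all, but I'm not sure it is the intended way to get information out of backward.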