In Python, you can set attributes on almost any object:
```python
class A:
    pass

a = A()
a.message = 'hello'
print(a.message)  # prints "hello"
```
In other words, this works and prints “hello”.
```python
import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.x
        x.message = 'hello'
        return torch.eye(x.shape[0])

x = torch.rand(2, 2, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.message)
```
However, according to this, such an approach may lead to memory leaks, and I am supposed to use save_for_backward instead.
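For comparison, the save_for_backward pattern would look roughly like this (a minimal sketch; the doubling operation and its gradient are just illustrative, not part of my actual use case):

```python
import torch

class MySaved(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # save_for_backward lets autograd track the tensor safely,
        # avoiding the reference cycles that plain ctx.x can create
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2  # d(2x)/dx = 2

x = torch.rand(3, requires_grad=True)
MySaved.apply(x).sum().backward()
print(x.grad)  # every entry is 2.0
```

But save_for_backward only stores the tensor itself, so it doesn't give me a place to hang extra per-tensor data like message.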
Thanks for the suggested dict-based solution. However, it doesn't scale when there is more than one x, because the messages would overwrite each other. I am pretty sure it's possible to overcome this limitation, but it all looks too weird to be a good approach.
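One way I could imagine working around the overwriting (a sketch only; Tagging, messages, and the key argument are hypothetical names for illustration) is to pass an explicit key into apply and store one dict entry per call:

```python
import torch

# hypothetical module-level side channel: one entry per call, under an explicit key
messages = {}

class Tagging(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, key):
        ctx.key = key  # a plain string, so nothing for autograd to leak
        return x * 1

    @staticmethod
    def backward(ctx, grad_out):
        messages[ctx.key] = 'hello'
        return grad_out, None  # None: no gradient for the non-tensor key argument

a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)
(Tagging.apply(a, 'a').sum() + Tagging.apply(b, 'b').sum()).backward()
print(messages)  # one 'hello' per tensor, no overwriting
```

This avoids the collision, but it still feels like exactly the kind of weird bookkeeping I'd rather not have.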