How to set attribute of input tensor in backward?

In Python, you can set attributes on almost any object:

```python
class A:
    pass

a = A()
a.message = 'hello'
print(a.message)  # prints "hello"
```

Similarly, the following works and prints “hello”:

```python
import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # store the input tensor directly on the context
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.x
        x.message = 'hello'  # attach an attribute to the input tensor
        return torch.eye(x.shape[0])  # dummy gradient, just for demonstration


x = torch.rand(2, 2, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.message)  # prints "hello"
```

However, according to this, such an approach may lead to memory leaks, and I should use save_for_backward instead.
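For reference, a minimal sketch of the recommended pattern, assuming the goal is simply to access the input inside backward: `save_for_backward` tracks the tensor safely and avoids the reference cycles that storing tensors as raw `ctx` attributes can create (`SafeFunction` is just an illustration name):

```python
import torch

class SafeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # the sanctioned way to keep tensors for backward
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors  # retrieve the saved input
        return grad_out  # forward is the identity, so the gradient passes through

x = torch.rand(2, 2, requires_grad=True)
SafeFunction.apply(x).sum().backward()
print(x.grad)  # all ones, since forward is the identity
```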

Thanks for the suggested dict-based solution. However, it doesn’t scale when there is more than one x, because the messages would overwrite each other. I’m fairly sure this limitation can be overcome, but it all looks too weird to be a good approach 🙂
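One hypothetical way around the overwriting problem is to key the dict by the input’s `id()`, recorded during forward; a plain int stored on `ctx` does not keep the tensor alive, so it sidesteps the leak concern. `messages` and `TaggedFunction` are illustration names, not anything from the PyTorch API:

```python
import torch

messages = {}  # assumed module-level store, keyed by tensor identity

class TaggedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.input_id = id(x)  # a plain int on ctx does not hold the tensor alive
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        messages[ctx.input_id] = 'hello'  # one entry per distinct input
        return grad_out

a = torch.rand(2, 2, requires_grad=True)
b = torch.rand(2, 2, requires_grad=True)
(TaggedFunction.apply(a) + TaggedFunction.apply(b)).sum().backward()
print(len(messages))  # one message per input, nothing overwritten
```

Whether this is less weird than the tensor-attribute trick is debatable, but it does keep the messages separate.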
