How to set attribute of input tensor in backward?

My goal is to set an attribute on a tensor that was saved for backward. The problem is that the tensor I get back from ctx.saved_tensors in backward is a new object, not the original one. What is the correct way of passing information out of backward?

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('>>> forward')
        print(f'id(x)         {id(x)}')
        print(f'id(x.data) {id(x.data)}')
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.saved_tensors[0]
        print('>>> backward')
        print(f'id(x)         {id(x)}')
        print(f'id(x.data) {id(x.data)}')
        x.message = 'hello'
        return grad_out  # forward is the identity, so the gradient passes through unchanged


x = torch.rand(2, 2, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.message)

Output:

>>> forward
id(x)         139894137130512
id(x.data) 139894137128640
>>> backward
id(x)         139894137130800
id(x.data) 139894137130872

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-b19465919e2e> in <module>
     22 x = torch.rand(2, 2, requires_grad=True)
     23 MyFunction.apply(x).sum().backward()
---> 24 print(x.message)

AttributeError: 'Tensor' object has no attribute 'message'
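The two id() lines are the crux: in this run, backward received a different Python object than forward did, so x.message was set on that other object and never reached the original x. Below is a plain-Python analogy. Storage and Wrapper are hypothetical stand-ins, not PyTorch internals, and whether you get a new wrapper object may also vary between PyTorch versions:

```python
class Storage:
    """Hypothetical stand-in for the underlying memory several tensors can share."""
    def __init__(self, values):
        self.values = values


class Wrapper:
    """Hypothetical stand-in for a Python-level tensor object."""
    def __init__(self, storage):
        self.storage = storage


original = Wrapper(Storage([1.0, 2.0]))
# A second Python object around the same storage, similar in spirit to
# what ctx.saved_tensors returned in the output above.
unpacked = Wrapper(original.storage)

unpacked.message = 'hello'
print(unpacked.storage is original.storage)  # True: same data
print(unpacked is original)                  # False: different Python objects
print(hasattr(original, 'message'))          # False: the attribute did not propagate
```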

Hello,

I think it is simply because the Tensor object has no attribute message.
How about creating a dict beforehand and changing this line:

x.message = 'hello'  -->  messages['x'] = 'hello'

Are there any better solutions? Let me know, thank you.
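A plain-Python sketch of the registry idea, where the messages dict and the record_message helper are hypothetical stand-ins for what backward would run. It shows that the choice of key decides whether entries coming from different inputs survive:

```python
# Hypothetical registry replacing x.message; named `messages` rather than
# `dict` so the built-in type is not shadowed.
messages = {}


def record_message(key, text):
    """Stand-in for the line backward would run: messages[key] = text."""
    messages[key] = text


# Fixed key: a second input's message replaces the first one.
record_message('x', 'hello from input 1')
record_message('x', 'hello from input 2')
print(messages)  # {'x': 'hello from input 2'}

# One key per input keeps both entries.
messages.clear()
record_message('input_1', 'hello from input 1')
record_message('input_2', 'hello from input 2')
print(messages)  # {'input_1': 'hello from input 1', 'input_2': 'hello from input 2'}
```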

In Python, you can set attributes on almost any object:

class A:
    pass
a = A()
a.message = 'hello'
print(a.message)  # prints "hello"

Likewise, the following version, which keeps x directly on ctx, works and prints “hello”:

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.x
        x.message = 'hello'
        return grad_out  # forward is the identity, so the gradient passes through unchanged


x = torch.rand(2, 2, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.message)

However, according to this, such an approach may lead to memory leaks, so I have to use save_for_backward.

Thanks for the suggested dict-based solution. However, it doesn’t scale if there is more than one x, because the messages would overwrite each other. I am pretty sure it’s possible to overcome this limitation; however, it all looks too weird to be a good approach 🙂
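One way to address both the overwriting and the memory-leak concern is sketched below. It assumes the caller is free to pass an extra non-tensor argument. Each call gets its own dict (the info argument is hypothetical, not part of the original code), that dict is stored on ctx, and backward fills it in. Only tensors go through save_for_backward. The ctx attribute holds a plain dict of strings and tuples, which is the kind of non-tensor state the PyTorch docs allow on ctx.

```python
import torch


class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, info):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(x)
        # `info` is a caller-owned plain dict (hypothetical extra argument);
        # storing non-tensor objects on ctx is fine.
        ctx.info = info
        return x.clone()  # identity op

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Write into the caller's dict instead of onto the tensor wrapper.
        ctx.info['message'] = 'hello'
        ctx.info['saved_shape'] = tuple(x.shape)
        # Gradient for x, and None for the non-tensor `info` argument.
        return grad_out, None


x1 = torch.rand(2, 2, requires_grad=True)
x2 = torch.rand(3, 3, requires_grad=True)
info1, info2 = {}, {}
out = MyFunction.apply(x1, info1).sum() + MyFunction.apply(x2, info2).sum()
out.backward()
print(info1)  # {'message': 'hello', 'saved_shape': (2, 2)}
print(info2)  # {'message': 'hello', 'saved_shape': (3, 3)}
```

Because each call writes into its own dict, results for x1 and x2 no longer collide. The trade-off is that the caller has to keep a reference to every dict it wants to read after backward.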
