Where are node values stored?

If I create a network (for example with torch.nn.Sequential) and then perform a forward pass, where are the values at the nodes stored? They are needed by the backward operation to backpropagate the gradient of the error. Are they stored somewhere in the network during the forward pass, or does the backward pass first run a forward pass to populate those node values in some temporary buffer?

I am using PyTorch for reinforcement learning, so I go back and forth through various states and need to be sure what state the network is in at any time.

Thank you,

Vincent

Hi,

These values are not stored at the torch.nn level. They are handled directly by autograd, which stores them along with all the information needed to perform the backward pass.
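
One way to observe this (a sketch, assuming a reasonably recent PyTorch; torch.autograd.graph.saved_tensors_hooks was added in 1.10) is to register saved-tensor hooks: the pack hook fires during the forward pass each time autograd saves a tensor for backward, and the unpack hook fires when backward retrieves it:

import torch

def pack_hook(t):
    # called during the forward pass, each time autograd stashes a tensor for backward
    print("saving tensor of shape", tuple(t.shape))
    return t

def unpack_hook(t):
    # called during the backward pass, when a saved tensor is needed again
    print("retrieving tensor of shape", tuple(t.shape))
    return t

net = torch.nn.Sequential(
    torch.nn.Linear(1, 1, bias=False),
    torch.nn.Linear(1, 1, bias=False)
)
x = torch.tensor([3.])

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = torch.square(net(x))

out.backward()  # the unpack hook fires here; no hidden extra forward pass is run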

Thank you for your reply, Alban. Can you tell me if I have it right on this simple example?

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(1, 1, bias=False),
    torch.nn.Linear(1, 1, bias=False)
)

x = torch.tensor([3.])
y = torch.square(net(x))  # forward pass: two linear layers, then an elementwise square
z = 2 * y

Here, a single computational graph is built “somewhere” by the autograd package. It contains Function objects. One of these Function objects can be accessed through y.grad_fn, and another through z.grad_fn. The computational graph holds the intermediate node values obtained during the forward pass. When z.backward() is called, the computational graph is used to fill the grad attributes of the parameters in net. The intermediate values are then freed, and a second call to z.backward() results in an error. Actually, even a call to y.backward() results in an error, which shows that the same computational graph is shared between those two values.
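
Continuing the snippet above, a quick check of that behaviour (the exact error wording may vary between versions) could look like:

z.backward()                 # fills net[0].weight.grad and net[1].weight.grad
print(net[0].weight.grad)

try:
    z.backward()             # the intermediate buffers were freed by the first call
except RuntimeError as e:
    print("second z.backward() failed:", e)

try:
    y.backward()             # same graph, same freed buffers, so this fails as well
except RuntimeError as e:
    print("y.backward() failed:", e)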

However, the graph is not completely destroyed after the call to backward(): z.grad_fn still exists and points to the MulBackward0 object. If y and z both go out of scope later, then the graph is completely freed.
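
As a small sketch, after running the lines above:

print(z.grad_fn)   # <MulBackward0 ...>: the graph node is still reachable after backward()
del y, z           # once nothing references the outputs, the remaining nodes can be freed as well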

Hi,

You are indeed correct. Note that the objects in the graph are now called Node rather than Function.

The reason the whole graph is not destroyed is that the graph itself is very small and lives in CPU memory, so it should never be an issue. What gets freed are only the buffers, i.e. the actual Tensors, which are what use a lot of memory.
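
As a side note, if you actually need to call backward twice on the same graph, retain_graph=True keeps those buffers alive (a minimal sketch; gradients accumulate into .grad across the two calls):

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(1, 1, bias=False),
    torch.nn.Linear(1, 1, bias=False)
)
x = torch.tensor([3.])
z = 2 * torch.square(net(x))

z.backward(retain_graph=True)   # buffers are kept, so the graph can be reused
z.backward()                    # second pass works; .grad now holds the sum of both passes
print(net[0].weight.grad)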
