Deleting Tensors in Context Save for Backward

Are the tensors saved for backward as shown below freed or deleted automatically after the backward pass?

ctx.save_for_backward(input, weight, bias)

I am trying to work around memory usage problems.

Yes, these tensors should be freed after the backward().
To double check it, you could use this example and add some print statements to check the memory:

import torch

# Custom autograd Function and data setup assumed from the standard
# "Defining new autograd Functions" tutorial; these sizes reproduce the
# memory numbers printed below.
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so it is available in backward.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

device = torch.device('cuda')
N, D_in, H, D_out = 64, 1000, 100, 10   # batch size, input, hidden, output dims

x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

learning_rate = 1e-6

for t in range(5):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    print('='*10, t, '='*10)
    print(torch.cuda.memory_allocated()/1024)
    y_pred = relu(x.mm(w1)).mm(w2)
    print(torch.cuda.memory_allocated()/1024)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(torch.cuda.memory_allocated()/1024)
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()
    print(torch.cuda.memory_allocated()/1024)

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad = None
        w2.grad = None
    
    print(torch.cuda.memory_allocated()/1024)

which shows that the memory drops back to the initial usage:

========== 0 ==========
647.5
700.0
703.0
1045.5
650.5
========== 1 ==========
650.5
700.5
703.0
1045.5
650.5
========== 2 ==========
650.5
700.5
703.0
1045.5
650.5
========== 3 ==========
650.5
700.5
703.0
1045.5
650.5
========== 4 ==========
650.5
700.5
703.0
1045.5
650.5

Note that I’ve replaced the .grad attributes with None instead of zeroing them out, so these tensors are freed as well.


Excellent, thank you so much. I had not yet thought about setting .grad to None after the backward. I totally missed that!

Currently, I was only deleting tensors whenever they were not needed anymore, as in X = torch.fft.fft2(x); del x. Deleting the model weights’ .grad will definitely improve things!

Thanks a lot!

In case you are using an optimizer or are calling zero_grad on the module itself, note that you can pass the set_to_none=True argument to zero_grad for exactly the same memory-saving reason.
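
For example, with a placeholder model and optimizer, just to show where the argument goes:

import torch

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(1000, 100)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

loss = model(torch.randn(64, 1000)).sum()
loss.backward()
optimizer.step()

# Frees the gradient tensors instead of filling them with zeros,
# equivalent to assigning p.grad = None manually.
optimizer.zero_grad(set_to_none=True)
# or, directly on the module:
model.zero_grad(set_to_none=True)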

Thank you so much again for these precious tips. I just had another question on this topic. Is there a way to free the tensors saved for backward, or the grad_output, before the end of backward?

Say I have something like:

@staticmethod
def backward(ctx, grad_output):
    ...
    del grad_output
    ...

I imagine that the above is pointless since there would still be a reference to grad_output on the call stack, right? Would there be a way to do something like this?

It doesn’t seem to make a difference if I del the tensor and check the memory before and after this operation.

Yes, I totally agree.

As far as I know, Python is heavily based on references instead of pointers and manages memory allocation on its own. But I wonder if there would be a way to do something inside backward that would invalidate the reference in the call stack, or somehow force a memory deallocation. I am just realizing now that this would be an off-(Torch-)topic question.
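
For example, this small standalone sketch (not my actual backward; it assumes a CUDA device) shows that del inside a function does not release the memory while the caller still holds a reference:

import torch

def consume(t):
    before = torch.cuda.memory_allocated()
    del t  # removes only the local name inside this frame
    print(before, torch.cuda.memory_allocated())  # unchanged: the caller still references the tensor

x = torch.randn(1024, 1024, device='cuda')
consume(x)
del x  # last reference gone: the memory goes back to the caching allocator
print(torch.cuda.memory_allocated())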

Tensors saved for the backward pass from forward ARE actually freed automatically during the corresponding backward!

If you are curious, you can check the autograd engine’s evaluate_function: we call fn.release_variables() there if keep_graph (i.e., the retain_graph argument) is set to False.
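
A small sketch of the difference (sizes are arbitrary):

import torch

x = torch.randn(1024, 1024, device='cuda', requires_grad=True)

loss = torch.relu(x.mm(x)).sum()      # the ReLU output is saved by the graph
print(torch.cuda.memory_allocated())  # includes the saved activation

loss.backward(retain_graph=True)      # keep_graph=True: release_variables is skipped
print(torch.cuda.memory_allocated())  # the saved activation is still alive

loss.backward()                       # default retain_graph=False: saved tensors are freed
print(torch.cuda.memory_allocated())  # drops back (only x and x.grad remain)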

Thanks for showing that piece of code. It’s very nice to know where that happens, but that was already pointed out by @ptrblck. Now I am wondering if there is a way to free these tensors (the ones saved for backward) and also, if possible, the grad_output somewhere inside the backward function.

One of the issues I face is that the calculations I do in the backward use a lot of memory. It is not really that much for a single layer, but considering that some CNN architectures can get really deep and wide, I just run out of memory in such scenarios. I am looking to release those tensors right after using them, while other computations still remain.

Maybe my only hope in those scenarios is to rely heavily on in-place operations.
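
Something along these lines is what I have in mind: reusing a single buffer inside backward instead of allocating a new temporary per step (the op and its math are just placeholders):

import torch

class MyCube(torch.autograd.Function):
    # Placeholder op: forward is x**3, so the gradient is 3 * x**2 * grad_output.

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.pow(3)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Allocate one buffer and update it in place, so the peak memory of
        # this backward is a single extra tensor rather than one temporary
        # per intermediate step.
        grad_input = input.pow(2)   # the only new allocation
        grad_input.mul_(3.0)
        grad_input.mul_(grad_output)
        return grad_input

x = torch.randn(8, requires_grad=True)
MyCube.apply(x).sum().backward()
print(torch.allclose(x.grad, 3 * x.detach().pow(2)))  # True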