Is it possible to keep the memory for only one row in a batch?

Imagine I run a forward function on an input x (batch size = 32) and get the result y. Is it possible to keep only one row of the output and release all the other memory? The example code is as follows:

import torch
from torch.nn import functional as F

base = 4000256  # bytes already allocated for params below: 1000*1000 float32, rounded up to the allocator's 512-byte blocks
params = torch.nn.Parameter(torch.randn(1000, 1000)).to("cuda")


def f(params: torch.nn.Parameter, x, keep, n=10):
    # x: (bsz, 1000)
    for i in range(n):
        x = F.linear(x, params)
        print(i, torch.cuda.memory_allocated(0) - base)
    result = x[keep]  # a view: it still shares storage with the whole (32, 1000) batch
    return result


print(torch.cuda.memory_allocated(0) - base)
x = torch.randn(32, 1000).to("cuda")  # 1408000
# x = torch.randn(1, 1000).to("cuda")  # 45056
print(torch.cuda.memory_allocated(0) - base)
y = f(params, x, 0)
torch.cuda.empty_cache()
print("e", torch.cuda.memory_allocated(0) - base)

It prints:

0
128000
0 256000
1 384000
2 512000
3 640000
4 768000
5 896000
6 1024000
7 1152000
8 1280000
9 1408000
e 1408000

One can see that even though I need only one row, the memory for the whole batch (1408000 bytes) is kept. What I expect is that only the memory for one row is retained (as if only one row had been processed from the start, i.e. 45056 in this case).

Is there any way to achieve this? Thanks in advance!

Maybe you could do the following? Plain indexing only gives you a view that shares storage with the whole tensor, so .clone() copies the row into its own small allocation and del x then drops the reference to the big one:

y = x[i,:].clone()
del x

Thanks! I tried your idea:

import torch
from torch.nn import functional as F

base = 4000256
params = torch.nn.Parameter(torch.randn(1000, 1000)).to("cuda")


def f(params: torch.nn.Parameter, x, keep, n=10):
    # x: (bsz, 1000)
    for i in range(n):
        x = F.linear(x, params)
        print(i, torch.cuda.memory_allocated(0) - base)
    result = x[keep, :].clone()  # copy the single row into its own storage
    del x  # drop the Python reference to the full (32, 1000) output
    return result


print(torch.cuda.memory_allocated(0) - base)
x = torch.randn(32, 1000).to("cuda")  # 1408000
# x = torch.randn(1, 1000).to("cuda")  # 45056
print(torch.cuda.memory_allocated(0) - base)
y = f(params, x, 0)
torch.cuda.empty_cache()
print("e", torch.cuda.memory_allocated(0) - base)

and it prints:

0
128000
0 256000
1 384000
2 512000
3 640000
4 768000
5 896000
6 1024000
7 1152000
8 1280000
9 1408000
e 1284096

1284096 is a bit lower than 1408000 (presumably del x freed the final 32×1000 float32 activation, 128000 bytes, while the cloned row adds 4096 bytes: 1408000 - 128000 + 4096 = 1284096), but ideally it should drop to 45056, i.e. what it would be if the batch size were 1 from the very beginning.

The memory is being occupied by the tensors that are saved for backward. You can either run the forward pass with gradients disabled (with torch.no_grad(): ...), or call .backward(), which has the same effect of clearing the saved tensors. After calling .backward(), I see the following on Colab:

0
128000
0 256000
1 384000
2 512000
3 640000
4 768000
5 896000
6 1024000
7 1152000
8 1280000
9 1408000
e 132096
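
For reference, here is a minimal sketch of both options, reusing params, f, x, and base from your snippet above; y.sum() is just a stand-in for whatever loss you actually compute:

# Option 1: no autograd graph is built, so each intermediate activation is
# freed as soon as the next loop iteration overwrites x.
with torch.no_grad():
    y = f(params, x, 0)

# Option 2: keep the graph but consume it; backward() releases the tensors
# that were saved for the linear layers' backward pass.
y = f(params, x, 0)
y.sum().backward()
torch.cuda.empty_cache()
print("e", torch.cuda.memory_allocated(0) - base)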

Thanks!

I do need the gradients, and backward indeed clears the saved tensors. However, I want the redundant memory to be released before calling backward: I can only compute the loss and run backward after encoding the whole sequence, and the redundant vectors take up a lot of memory.
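
To make the setting concrete, here is a rough sketch of my use case, reusing params and f from above; the random step inputs, the five steps, and the .sum() loss are only stand-ins for my real data and loss:

# One row is kept per step, but the loss (and therefore backward) is only
# available after the whole sequence has been encoded, so the tensors saved
# for autograd pile up across all steps until then.
kept_rows = []
for _ in range(5):  # stand-in "sequence steps"
    step_x = torch.randn(32, 1000).to("cuda")
    y = f(params, step_x, keep=0)  # full-batch forward, only one row is needed
    kept_rows.append(y)

loss = torch.stack(kept_rows).sum()  # stand-in loss, only computable at the end
loss.backward()  # the saved activations are released only here

Even with the clone, everything autograd saved during each call to f stays allocated until that final backward, and that is the memory I would like to release earlier.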

Do you have any idea how? Or can this be achieved?