Imagine I run a forward function on an input x (batch size = 32) and get the result y. Is it possible to keep only one row of the output and release all the other memory? The example code is as follows:
import torch
from torch.nn import functional as F

base = 4000256
params = torch.nn.Parameter(torch.randn(1000, 1000)).to("cuda")

def f(params: torch.nn.Parameter, x, keep, n=10):
    # x: (bsz, 1000)
    for i in range(n):
        x = F.linear(x, params)
        print(i, torch.cuda.memory_allocated(0) - base)
    result = x[keep]
    return result

print(torch.cuda.memory_allocated(0) - base)
x = torch.randn(32, 1000).to("cuda")  # 1408000
# x = torch.randn(1, 1000).to("cuda")  # 45056
print(torch.cuda.memory_allocated(0) - base)
y = f(params, x, 0)
torch.cuda.empty_cache()
print("e", torch.cuda.memory_allocated(0) - base)
One can see that even though I need only one row, the memory for the whole batch is kept: 1408000 bytes, i.e. the input plus the 10 intermediate outputs, 11 tensors of 32×1000 float32 at 128000 bytes each. What I expect is that only the memory for one row is retained, as if a single row had been processed, i.e. 45056 bytes in this case.
Is there any way to achieve this? Thanks in advance.

I also tried copying out the row with .clone() and deleting x, but the memory for the whole batch is still retained:
import torch
from torch.nn import functional as F

base = 4000256
params = torch.nn.Parameter(torch.randn(1000, 1000)).to("cuda")

def f(params: torch.nn.Parameter, x, keep, n=10):
    # x: (bsz, 1000)
    for i in range(n):
        x = F.linear(x, params)
        print(i, torch.cuda.memory_allocated(0) - base)
    result = x[keep, :].clone()  # copy the single row we need
    del x  # drop the local reference to the full batch
    return result

print(torch.cuda.memory_allocated(0) - base)
x = torch.randn(32, 1000).to("cuda")  # 1408000
# x = torch.randn(1, 1000).to("cuda")  # 45056
print(torch.cuda.memory_allocated(0) - base)
y = f(params, x, 0)
torch.cuda.empty_cache()
print("e", torch.cuda.memory_allocated(0) - base)
The memory is being occupied by the tensors that are saved for backward. You can either run the forward pass with gradients disabled (with torch.no_grad(): ...), or call .backward(), which has the same effect of clearing the saved tensors. After calling .backward(), I see the following on Colab:
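For reference, a minimal sketch of the no_grad variant (reusing params, x, and the clone version of f from above), assuming the gradients are genuinely not needed:

with torch.no_grad():
    y = f(params, x, 0)  # nothing is saved for backward
# each old activation is freed as soon as x is rebound inside the loop,
# so only the input x and the cloned row remain allocated afterwards
torch.cuda.empty_cache()
print("no_grad", torch.cuda.memory_allocated(0) - base)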
I do need to keep the gradients, and .backward() indeed clears the saved tensors. However, I want the redundant memory to be released before calling backward, since I can only compute the loss and call backward after encoding the whole sequence, and the redundant activations take up a lot of memory.
Do you have any idea how? Can this even be achieved?
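One technique that fits this constraint, though it is not mentioned in the thread, is gradient checkpointing via torch.utils.checkpoint.checkpoint: the forward pass does not store the activations inside the checkpointed block and recomputes them during .backward(), so before backward only the block's input (and whatever you return) stays allocated. Below is a minimal sketch under these assumptions: f_ckpt is a hypothetical rewrite of f, the parameter is created directly on the GPU so it remains a leaf, and use_reentrant=False requires a reasonably recent PyTorch.

import torch
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

params = torch.nn.Parameter(torch.randn(1000, 1000, device="cuda"))

def f_ckpt(params, x, keep, n=10):
    def block(x):
        for _ in range(n):
            x = F.linear(x, params)
        return x
    # only block's input is saved for backward; the 10 intermediate
    # activations are recomputed inside .backward() instead of being stored
    x = checkpoint(block, x, use_reentrant=False)
    return x[keep, :].clone()

x = torch.randn(32, 1000, device="cuda", requires_grad=True)
y = f_ckpt(params, x, 0)
print(torch.cuda.memory_allocated(0))  # far less than the 1408000 retained without checkpointing
y.sum().backward()  # gradients still reach params and x

The trade-off is extra compute: the block's forward runs twice, once in the forward pass and once again during backward to rebuild the activations.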