Could someone explain the memory usage for this block of code?
```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def cuda_memory(msg):
    print("usage after", msg, torch.cuda.memory_allocated(device)/1024**2)

#with torch.no_grad():
with torch.enable_grad():
    dim, rank, outer_product_layers = 768, 3, 4
    vocab_size, seq_len = 10, 20
    inputs = torch.randint(0, vocab_size, (seq_len,))
    cuda_memory("initial")  # 0.0
    acts = nn.Embedding(vocab_size, dim)(inputs).to(device)
    cuda_memory("inputs on device")  # 0.029
    linear = torch.randn(dim, dim, requires_grad=True).to(device)
    cuda_memory("linear on device")  # 2.279
    acts = torch.matmul(acts, linear)
    cuda_memory("linear activations")  # 10.404
    for layer in range(outer_product_layers):
        u = torch.randn(dim, rank, requires_grad=True).to(device)
        v = torch.randn(rank, dim, requires_grad=True).to(device)
        cuda_memory(f"u and v on device layer {layer}")  # increases ~0.02 each time
        acts = torch.matmul(acts, linear + torch.matmul(u, v))
        #acts = torch.matmul(acts, linear)  # memory doesn't increase much
        #acts = torch.matmul(acts, torch.matmul(u, v))  # memory increases about the same amount
        cuda_memory(f"layer {layer} activations")  # increases ~2.25 each time
```
I was attempting a weight-sharing scheme in which each layer's weights are a low-rank update added to the previous layer's weights. Naively, I thought this would save a lot of GPU memory by re-using the weight values from the initial linear layer. But it looks like some intermediate values are being saved as well. Is it the activations, the product of u and v, or both? Is saving them required in order to calculate the gradients? The memory bump doesn't happen if I change enable_grad() to no_grad(). It also doesn't happen if I simply re-use the linear layer at each step; it does happen if I use the outer product at each step without adding linear.
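To try to see what autograd is actually holding on to, here is a small standalone sketch (separate from the script above; the tensor names are just stand-ins for the ones in my code) that uses torch.autograd.graph.saved_tensors_hooks to print every tensor the backward graph saves for one such layer:

```python
import torch

def pack(t):
    # Called once for every tensor autograd saves for the backward pass.
    print("saved for backward:", tuple(t.shape), t.dtype)
    return t

def unpack(t):
    return t

dim, rank, seq_len = 768, 3, 20
acts = torch.randn(seq_len, dim, requires_grad=True)  # stands in for the embedding output
linear = torch.randn(dim, dim, requires_grad=True)
u = torch.randn(dim, rank, requires_grad=True)
v = torch.randn(rank, dim, requires_grad=True)

# CPU is fine here; the autograd graph is built the same way either way.
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    acts = torch.matmul(acts, linear + torch.matmul(u, v))
```

If I'm reading the output of that hook correctly, matmul keeps its inputs around for backward, so each layer seems to hold on to the freshly materialized 768x768 tensor linear + u@v (~2.25 MB in fp32) in addition to the small u and v factors, which would explain the per-layer bump. Re-using linear alone presumably costs nothing extra because that tensor already exists on the device.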
Thanks in advance for any insights.
Edit: Thanks for the formatting advice!