Confused about stored activation memory usage

Could someone explain the memory usage for this block of code?

```
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def cuda_memory(msg):
    print("usage after", msg, torch.cuda.memory_allocated(device)/1024**2)

#with torch.no_grad():
with torch.enable_grad():
    dim, rank, outer_product_layers = 768, 3, 4
    vocab_size, seq_len = 10, 20
    inputs = torch.randint(0, vocab_size, (seq_len,))
    cuda_memory("initial") # 0.0
    acts = nn.Embedding(vocab_size, dim)(inputs).to(device)
    cuda_memory("inputs on device") # 0.029
    linear = torch.randn(dim, dim, requires_grad=True).to(device)
    cuda_memory("linear on device") # 2.279
    acts = torch.matmul(acts, linear)
    cuda_memory("linear activations") # 10.404
    for layer in range(outer_product_layers):
        u = torch.randn(dim, rank, requires_grad=True).to(device)
        v = torch.randn(rank, dim, requires_grad=True).to(device)
        cuda_memory(f"u and v on device layer {layer}") # increases ~0.02 each time
        acts = torch.matmul(acts, linear + torch.matmul(u, v))
        #acts = torch.matmul(acts, linear) # memory doesn't increase much
        #acts = torch.matmul(acts, torch.matmul(u, v)) # memory increases about the same amount
        cuda_memory(f"layer {layer} activations") # increases ~2.25 each time
```

I'm attempting a weight-sharing scheme in which each layer's weights are the previous layer's weights plus a low-rank update (the outer product of u and v). Naively, I expected this to save a lot of GPU memory by reusing the weight values from the initial linear layer. But it looks like some intermediate values are being stored as well, either the activations or the product of u and v. Is that required in order to compute the gradients? The ~2.25 MB bump per layer doesn't happen if I change enable_grad() to no_grad(). It also doesn't happen if I simply reuse the linear layer at each step, but it does happen if I use only the outer product at each step without adding linear.
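
For what it's worth, here is a small standalone sketch (separate from the code above, with fresh tensors of the same shapes) that I think shows what autograd stashes, using the public `torch.autograd.graph.saved_tensors_hooks` API to print every tensor saved for the backward pass:

```
import torch

dim, rank, seq_len = 768, 3, 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

acts = torch.randn(seq_len, dim, device=device, requires_grad=True)
linear = torch.randn(dim, dim, device=device, requires_grad=True)
u = torch.randn(dim, rank, device=device, requires_grad=True)
v = torch.randn(rank, dim, device=device, requires_grad=True)

def pack(t):
    # autograd calls this for every tensor it saves for backward
    mib = t.numel() * t.element_size() / 1024**2
    print(f"saved for backward: shape={tuple(t.shape)}, {mib:.3f} MiB")
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = torch.matmul(acts, linear + torch.matmul(u, v))
```

If I'm reading the output right, the pack hook fires for u, v, the incoming activations, and the full (768, 768) result of linear + u @ v (~2.25 MiB), which would line up with the per-layer bump I'm seeing, though I may be misinterpreting it.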

Thanks in advance for any insights.

Edit: Thanks for the formatting advice!

Could you format your code by wrapping it into three backticks ```, please?


Done, thanks for the tip.