Memory Usage of a Modified Convolution

I want to do a modified version of 1D convolution. Instead of computing the weighted sum, I want to compute the channel-wise distance between the kernel and the data. However, computing the distance takes a lot of GPU memory once gradients are tracked, and even more after backpropagation. Here is some sample code I wrote:

import torch

device = torch.device("cuda")
data = torch.rand(128, 5, 1000, device=device)
# Sliding windows of length 10: x has shape (128, 991, 1, 5, 10)
x = data.unfold(-1, 10, 1).permute(0, 2, 1, 3).unsqueeze(2)
w = torch.rand(200, 5, 10, requires_grad=True, device=device)
print(f"GPU memory usage: {torch.cuda.memory_reserved(0)/1e9:.2f}")
# Broadcasts to (128, 991, 200, 5, 10) before reducing over the last dim
d = (x - w).norm(dim=-1)
print(f"GPU memory usage: {torch.cuda.memory_reserved(0)/1e9:.2f}")
loss = d.mean()
print(f"GPU memory usage: {torch.cuda.memory_reserved(0)/1e9:.2f}")
loss.backward()
print(f"GPU memory usage: {torch.cuda.memory_reserved(0)/1e9:.2f}")

I got the following output:

GPU memory usage: 0.53
GPU memory usage: 6.11
GPU memory usage: 6.11
GPU memory usage: 16.39

This is not an issue when using nn.Conv1d, probably because of some low-level optimisation. Is there any way to make the distance computation more memory efficient?

Thanks for your help.

By unfolding data, you are creating an intermediate tensor x that is 10 times the size. You materialize it when you do the subtraction, and then promptly reduce it away. One way to reduce memory is to do the computation piece by piece and accumulate the summation in the norm as you go. You’ll need to figure out how to manipulate the shapes in the right way though.
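As a rough sketch of the piece-by-piece idea: the function name chunked_distance is hypothetical, and looping over the window offsets is only one possible way to split the work. Each offset of every window is a plain slice of data, so the squared differences can be accumulated without ever building the full (batch, positions, kernels, channels, window) tensor.

```python
import torch

def chunked_distance(data, w):
    """Channel-wise L2 distance between sliding windows of data and kernels.

    data: (B, C, L), w: (K, C, S) -> result of shape (B, T, K, C), T = L - S + 1.
    """
    B, C, L = data.shape
    K, _, S = w.shape
    T = L - S + 1
    sq = data.new_zeros(B, K, C, T)  # running sum of squared differences
    for j in range(S):
        # Offset j of every window is just data[..., j:j+T]; no unfold needed.
        diff = data[:, :, j:j + T].unsqueeze(1) - w[:, :, j].reshape(1, K, C, 1)
        sq += diff * diff
    # Same layout as (x - w).norm(dim=-1) in the question.
    return sq.sqrt().permute(0, 3, 1, 2)
```

Note that when w requires grad, autograd still keeps each diff for the backward pass, so on its own this mainly lowers the peak for forward-only use (for example under torch.no_grad()). For training it would need to be combined with something like checkpointing or a hand-written backward.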

Btw, you’ll want to use torch.cuda.memory_allocated (or torch.cuda.max_memory_allocated for the peak) instead of memory_reserved. Even after tensors get cleared, PyTorch will still hold on to that memory rather than return it to CUDA, for efficiency reasons. The amount of memory that PyTorch holds onto is the value returned by memory_reserved, whether or not that memory corresponds to actual tensors that are alive.
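A small sketch of how the three counters could be compared; the helper name cuda_memory_gb is made up for illustration, and it simply returns None on a machine without CUDA.

```python
import torch

def cuda_memory_gb(device=0):
    """Return allocated, peak-allocated and reserved CUDA memory in GB."""
    if not torch.cuda.is_available():
        return None  # nothing to report on a CPU-only machine
    return {
        # Memory occupied by tensors that are currently alive.
        "allocated": torch.cuda.memory_allocated(device) / 1e9,
        # Highest value of "allocated" since the last reset.
        "peak_allocated": torch.cuda.max_memory_allocated(device) / 1e9,
        # Memory held by PyTorch's caching allocator, live tensors or not.
        "reserved": torch.cuda.memory_reserved(device) / 1e9,
    }

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    a = torch.rand(1024, 1024, device="cuda")
    b = (a * 2).sum()  # the temporary a * 2 is freed right after the sum
    del a
    print(cuda_memory_gb())
```

After the temporaries are freed, allocated should drop while peak_allocated and reserved stay higher, which is why memory_reserved overstates what live tensors actually use.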

Thanks a lot for your answer! I checked the memory usage using torch.cuda.max_memory_allocated() and the peak memory usage is indeed very high because of intermediate variables.

My thought is that, mathematically, the distance computation should have similar complexity to a typical convolution. Currently, I am trying to learn autograd.Function, which seems like a solution to this issue: compute the gradient by hand and avoid storing intermediate results (see "Extending PyTorch" in the PyTorch documentation).
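For reference, a minimal sketch of what such an autograd.Function might look like. The names ChannelDistance and channel_distance are hypothetical, the 1e-12 clamp is an arbitrary guard against division by zero, and the sketch assumes only w needs a gradient (it returns None for data). Only data, w and the distances are saved for backward; the per-offset differences are recomputed implicitly through the closed-form gradient.

```python
import torch

class ChannelDistance(torch.autograd.Function):
    """d[b, k, c, t] = || data[b, c, t:t+S] - w[k, c, :] ||_2 with a manual backward."""

    @staticmethod
    def forward(ctx, data, w):
        B, C, L = data.shape
        K, _, S = w.shape
        T = L - S + 1
        sq = data.new_zeros(B, K, C, T)
        for j in range(S):  # accumulate squared differences offset by offset
            diff = data[:, :, j:j + T].unsqueeze(1) - w[:, :, j].reshape(1, K, C, 1)
            sq += diff * diff
        d = sq.sqrt_()
        ctx.save_for_backward(data, w, d)  # no per-offset intermediates are kept
        return d

    @staticmethod
    def backward(ctx, grad_out):
        data, w, d = ctx.saved_tensors
        T = d.shape[-1]
        # dd/dw[k, c, j] = (w[k, c, j] - data[b, c, t + j]) / d[b, k, c, t]
        r = grad_out / d.clamp_min(1e-12)  # (B, K, C, T); clamp is an assumption
        r_sum = r.sum(dim=(0, 3))          # (K, C)
        grad_w = torch.empty_like(w)
        for j in range(w.shape[-1]):
            corr = torch.einsum("bkct,bct->kc", r, data[:, :, j:j + T])
            grad_w[:, :, j] = w[:, :, j] * r_sum - corr
        return None, grad_w  # assumption: data does not require grad

def channel_distance(data, w):
    # Reorder to (B, T, K, C), matching (x - w).norm(dim=-1) from the question.
    return ChannelDistance.apply(data, w).permute(0, 3, 1, 2)
```

With this, d = channel_distance(data, w) followed by d.mean().backward() should populate w.grad while keeping only tensors of the output's size alive between forward and backward.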

Thanks again for your answer.