Debugging Large Memory Usage

Hi there, I’m currently working on a piece of code that uses torch.repeat_interleave() in quite a few places to repeat data inside a neural network. I’ve been trying to optimize it by replacing these calls with torch.expand, but I’m still running out of memory. Is there a way to figure out which operations are using large amounts of memory? I’ve done my best to use torch.gather for indexing and torch.expand to avoid allocating new tensors. I suspect the issue might lie in the computational graph, but I have no idea how to check which operations are causing the large memory usage.
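For context, here is a minimal sketch of the kind of replacement I mean (the shapes are made up for illustration): repeat_interleave materializes a new tensor, while expand returns a view over the original storage.

```python
import torch

# Hypothetical shapes: a (B, 1, D) tensor that must be broadcast
# along dim 1 to match a sequence length S.
B, S, D = 4, 1024, 256
x = torch.randn(B, 1, D)

# repeat_interleave materializes a full (B, S, D) tensor in new memory ...
repeated = x.repeat_interleave(S, dim=1)

# ... while expand returns a stride-0 view over the same storage.
expanded = x.expand(B, S, D)

# Both produce the same values.
assert torch.equal(repeated, expanded)
# The expanded view shares storage with x: no new memory was allocated.
assert expanded.data_ptr() == x.data_ptr()
# But the view is not contiguous, which matters for some downstream ops.
assert not expanded.is_contiguous()
```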

You could add print statements to your code and check the memory usage after specific operations, e.g. via print(torch.cuda.memory_summary()) or print(torch.cuda.memory_allocated()).
Using expand will not allocate new memory, but note that some operations depend on contiguous data and might call tensor.contiguous() internally, which will then allocate the memory.
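A small sketch of that pitfall (toy shapes for illustration): the expanded view reports the full element count but still uses the original small storage, and calling .contiguous() materializes the whole thing.

```python
import torch

x = torch.arange(6).reshape(2, 3, 1)
view = x.expand(2, 3, 4)  # stride-0 view along the last dim, no copy

# The view still points at x's 6-element storage.
assert view.data_ptr() == x.data_ptr()

# contiguous() allocates real memory for all 24 elements.
copy = view.contiguous()
assert copy.data_ptr() != x.data_ptr()
assert x.numel() == 6 and copy.numel() == 24
```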