How to optimize memory allocation with intermediate tensor operations

I’m creating a custom nn.Module and I’ve been running into GPU OOM problems because at times the tensors within the module can get quite large. From a memory management perspective, I know that tensor methods like Tensor.view, Tensor.expand, etc. are cheap, as they return views that share the underlying storage instead of copying it. To better optimise my code, I have been using pytorch_memlab to inspect the memory footprint.
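A small illustration of why expand is cheap. This is a sketch with toy shapes, not the module from this thread. The expanded dimension gets a stride of 0, so all 50 "rows" point at the same memory. Only an explicit copy such as contiguous() allocates the full-size tensor.

```python
import torch

a = torch.randn(4, 1, 3)
e = a.expand(-1, 50, -1)   # view: stride 0 along the expanded dim, no copy
print(e.shape, e.stride())                 # torch.Size([4, 50, 3]) (3, 0, 1)
print(e.data_ptr() == a.data_ptr())        # True: same underlying memory
print(e.is_contiguous())                   # False
c = e.contiguous()                         # forces a real copy of 4 * 50 * 3 elements
print(c.stride(), c.data_ptr() == a.data_ptr())  # (150, 3, 1) False
```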

I have noticed that the following operations do not allocate the intermediate tensor c in memory, i.e. the pointwise multiplication result is not stored and is instead passed directly to torch.sum.

def some_module_func(self, a, b):
    # First expand a & b using cheap operation expand
    a = a.expand(-1, 50, 50, 50, -1, -1)
    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication; if this tensor were stored in memory it would be large
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
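For this particular pattern, a sum over dims 1 to 3 of a product of expanded tensors, the product can often be avoided entirely. This is an idea that does not come from the thread. Along any dim where one operand has size 1 it is constant, so the other operand can be summed first. Along a dim where both have size 1, the sum is just a multiplication by the expanded size. The helper name `expanded_mul_sum` and its `sizes` parameter are hypothetical. It assumes 6-D inputs whose dims 1 to 3 are each either 1 or the target size, which is the same precondition expand() needs.

```python
import torch

def expanded_mul_sum(a, b, sizes=(50, 50, 50)):
    """Hypothetical helper equal to
    torch.sum(a.expand(-1, *sizes, -1, -1) * b.expand(-1, *sizes, -1, -1), dim=(1, 2, 3))
    without materializing the broadcast product.

    Assumes a and b are 6-D and that, for dims 1-3, each has size 1 or the
    matching entry of `sizes`.
    """
    scale = 1
    for dim, size in zip((1, 2, 3), sizes):
        a_len, b_len = a.shape[dim], b.shape[dim]
        if a_len == 1 and b_len == 1:
            scale *= size                          # sum of `size` identical terms
        elif a_len == 1:
            b = b.sum(dim=dim, keepdim=True)       # a is constant here: sum b first
        elif b_len == 1:
            a = a.sum(dim=dim, keepdim=True)       # b is constant here: sum a first
        # otherwise both vary along `dim` and the product must keep it
    return torch.sum(a * b, dim=(1, 2, 3)) * scale

torch.manual_seed(0)
a = torch.randn(2, 50, 1, 1, 3, 4, dtype=torch.float64)
b = torch.randn(2, 1, 50, 1, 3, 4, dtype=torch.float64)
ref = torch.sum(a.expand(-1, 50, 50, 50, -1, -1) * b.expand(-1, 50, 50, 50, -1, -1), dim=(1, 2, 3))
print(torch.allclose(expanded_mul_sum(a, b), ref))  # True
```

The only intermediate kept at full size along a dim is one where both operands genuinely vary. In that case the product still has to hold that dim, but it is never larger than the un-expanded data requires.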

However, an operation like torch.gather seemingly allocates a new tensor for its output. Here a full-size copy of the expanded tensor a is created in memory by the a.gather operation.

def some_other_module_func(self, a, b, indexer):
    a = a.expand(-1, 50, 50, 50, -1, -1)
    # Here we rearrange dimension 1 to some predefined order
    a = a.gather(dim=1, index=indexer) # this large tensor is written to memory

    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
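If indexer applies the same reordering of dim 1 at every position, one option is to reorder the small tensor before expanding. The comment "some predefined order" suggests this, but the thread does not confirm it. gather always allocates an output with the shape of its index. By contrast, index_select on the un-expanded tensor only copies a.numel() elements, and the subsequent expand is still a view. The helper `reorder_then_expand` and the shapes below are illustrative assumptions.

```python
import torch

def reorder_then_expand(a, order, sizes=(50, 50, 50)):
    """Hypothetical helper: permute dim 1 of the *small* tensor, then expand.

    Assumes a has shape (N, sizes[0], 1, 1, X, Y) and that `order` is a 1-D
    LongTensor giving the same permutation of dim 1 at every position.
    """
    a = a.index_select(1, order)           # allocates only a.numel() elements
    return a.expand(-1, *sizes, -1, -1)    # still a stride-0 view, no copy

torch.manual_seed(0)
a = torch.randn(2, 50, 1, 1, 3, 4)
order = torch.randperm(50)

# Original approach: gather on the expanded view materializes the full output
index = order.view(1, 50, 1, 1, 1, 1).expand(2, 50, 50, 50, 3, 4)
gathered = a.expand(-1, 50, 50, 50, -1, -1).gather(1, index)

cheap = reorder_then_expand(a, order)
print(torch.equal(gathered, cheap))   # True
print(cheap.stride()[2:4])            # (0, 0)
```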

My question is: to better optimise my code, how can I tell which operations/functions are likely to defer memory allocation? Is there any way I can apply Tensor.gather, for example, without its output being explicitly allocated in memory, as with torch.mul?
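One way to check empirically whether an op returns a view or a freshly allocated tensor is to compare storage addresses. The helper `shares_storage` is a hypothetical name, and `untyped_storage()` assumes PyTorch 2.x. Ops like reshape or contiguous can go either way depending on the input's strides, so they are left out of this sketch.

```python
import torch

def shares_storage(x, y):
    # Two tensors view the same data if their storages start at the same address
    return x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()

base = torch.randn(2, 1, 3)
ops = {
    "expand": lambda t: t.expand(-1, 5, -1),
    "view": lambda t: t.view(2, 3),
    "permute": lambda t: t.permute(2, 0, 1),
    "mul": lambda t: torch.mul(t, 2.0),
    "gather": lambda t: t.gather(2, torch.zeros(2, 1, 3, dtype=torch.long)),
    "sum": lambda t: t.sum(dim=2),
}
results = {name: shares_storage(base, op(base)) for name, op in ops.items()}
print(results)
# {'expand': True, 'view': True, 'permute': True, 'mul': False, 'gather': False, 'sum': False}
```

In other words, none of the non-view ops defer allocation. torch.mul and torch.gather both write a full output tensor.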

I don’t think that’s true as I’m seeing the expected increase in memory:

import torch

a = torch.randn(1, 1, 1, 1, 1, 1, device='cuda')
b = torch.randn(1, 1, 1, 1, 1, 1, device='cuda')

print(torch.cuda.memory_allocated())
# 2048

a = a.expand(-1, 50, 50, 50, -1, -1)
b = b.expand(-1, 50, 50, 50, -1, -1)
print(torch.cuda.memory_allocated())
# 2048

c = torch.mul(a, b)
print(torch.cuda.memory_allocated())
# 502272

d = torch.sum(c, dim=(1, 2, 3))
print(torch.cuda.memory_allocated())
# 502784
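The printed numbers can be reproduced by hand. This assumes float32 tensors (4 bytes per element) and assumes that the CUDA caching allocator rounds each request up to a multiple of 512 bytes. The 2048-byte baseline is taken as printed. The helper `rounded_bytes` is hypothetical.

```python
import math

def rounded_bytes(numel, element_size=4, block=512):
    # Assumption: each allocation is rounded up to a multiple of `block` bytes
    return math.ceil(numel * element_size / block) * block

baseline = 2048                    # value printed before the multiplication
c_bytes = rounded_bytes(50 ** 3)   # 500,000 bytes of data -> 500224
d_bytes = rounded_bytes(1)         # a single float -> 512
print(baseline + c_bytes, baseline + c_bytes + d_bytes)  # 502272 502784
```

So the jump after torch.mul is exactly one full 50 x 50 x 50 float32 tensor.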


Ah, so I did not realise that the memory was freed once the method returned and its local variables went out of scope. The memory checks I was performing previously were therefore hiding the true memory requirements, as I was only checking memory allocation outside the module method.

class Module(torch.nn.Module):
    
    def forward(self, a, b):
        a = a.expand(-1, 50, 50, 50, -1, -1)
        b = b.expand(-1, 50, 50, 50, -1, -1)
        print(torch.cuda.memory_allocated())
        c = torch.mul(a, b)
        print(torch.cuda.memory_allocated())
        d = torch.sum(c, dim=(1, 2, 3))
        print(torch.cuda.memory_allocated())

        return d

module = Module()

d = module(a, b)
# 1024
# 501248
# 501760

print(torch.cuda.memory_allocated())
# 1536
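Instead of printing inside forward, the peak memory of a call can also be captured from outside using the allocator's peak statistics. This also includes temporaries like c that are already freed when the call returns. The helper name `peak_extra_bytes` is hypothetical, and the sketch only runs its example when a GPU is available.

```python
import torch

def peak_extra_bytes(fn, *args):
    """Run fn(*args) and return (output, peak bytes allocated above the starting level)."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.memory_allocated()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, torch.cuda.max_memory_allocated() - start

def mul_sum(x, y):
    # Same computation as the module above; the product is freed on return
    x = x.expand(-1, 50, 50, 50, -1, -1)
    y = y.expand(-1, 50, 50, 50, -1, -1)
    return torch.sum(torch.mul(x, y), dim=(1, 2, 3))

if torch.cuda.is_available():
    a = torch.randn(1, 1, 1, 1, 1, 1, device="cuda")
    b = torch.randn(1, 1, 1, 1, 1, 1, device="cuda")
    d, extra = peak_extra_bytes(mul_sum, a, b)
    print(extra)  # includes the ~500 kB product even though it is already freed
```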

So I guess my main question is: is there any way to avoid writing the intermediate elementwise multiplication to memory? I tried using torch.utils.checkpoint.checkpoint but still ran into OOM problems.
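torch.utils.checkpoint trades recomputation for not keeping activations around for the backward pass. It does not lower the peak of the forward computation itself, so the full product is still allocated once. A general workaround, which is not from the thread, is to bound the intermediate by reducing in slices of dim 1 and accumulating. The helper `chunked_mul_sum` and its `chunk` parameter are hypothetical. It assumes both inputs already have the same (possibly expanded) shape.

```python
import torch

def chunked_mul_sum(a, b, chunk=10):
    """Hypothetical helper: torch.sum(a * b, dim=(1, 2, 3)) computed in slices of dim 1.

    Assumes a and b already have the same shape (e.g. both expanded as in the
    thread). Slicing an expanded view is still a view, so only one slice of the
    product, roughly chunk / a.shape[1] of the full tensor, exists at a time.
    """
    total = None
    for start in range(0, a.shape[1], chunk):
        part = torch.sum(
            a[:, start:start + chunk] * b[:, start:start + chunk], dim=(1, 2, 3)
        )
        total = part if total is None else total + part
    return total

torch.manual_seed(0)
a = torch.randn(2, 1, 1, 1, 3, 4, dtype=torch.float64).expand(-1, 50, 50, 50, -1, -1)
b = torch.randn(2, 50, 1, 1, 3, 4, dtype=torch.float64).expand(-1, 50, 50, 50, -1, -1)
print(torch.allclose(chunked_mul_sum(a, b), torch.sum(a * b, dim=(1, 2, 3))))  # True
```

Smaller chunks lower the peak at the cost of more kernel launches. The same slicing could be applied after a gather by gathering one slice at a time.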

@mslyon I was wondering if you were able to figure out how to optimize the memory management for your case?

@Abhishek_Tyagi No unfortunately I did not.