# How to optimize memory allocation with intermediate tensor operations

I’m creating a custom `nn.Module` and I’ve been running into GPU OOM problems because at times the tensors within the module can get quite large. From a memory management perspective I know that tensor methods like `Tensor.view`, `Tensor.expand`, etc. are cheap, as they return views rather than copying the underlying data. To better optimise my code I have been using `pytorch_memlab` to inspect the memory footprint.
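For context, one quick way to confirm that `expand` returns a view rather than a copy (a minimal CPU sketch, not from the original post) is to compare data pointers and strides:

```python
import torch

a = torch.randn(1, 4)
b = a.expand(3, 4)  # broadcast dim 0 to size 3 without copying

# Same underlying storage; the expanded dimension gets stride 0
print(b.data_ptr() == a.data_ptr())  # True
print(b.stride())                    # (0, 1)
```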

I have noticed that the following operations do not appear to allocate the intermediate tensor `c` in memory, i.e. the pointwise multiplication result is seemingly not stored, and is instead immediately passed to `torch.sum`.

```python
def some_module_func(self, a, b):
    # First expand a & b using the cheap operation expand
    a = a.expand(-1, 50, 50, 50, -1, -1)
    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication; if this tensor were saved in memory it would be large
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
```

However, an operation like `torch.gather` seemingly does allocate a new output tensor in memory. Here the expanded tensor `a` will be materialized in memory by the `a.gather` operation.

```python
def some_other_module_func(self, a, b, indexer):
    a = a.expand(-1, 50, 50, 50, -1, -1)
    # Here we rearrange dimension 1 to some predefined order
    a = a.gather(dim=1, index=indexer)  # this large tensor is written to memory

    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
```
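This can be checked directly: unlike `expand`, `gather` cannot express its output as a strided view of the input, so it writes a fresh tensor. A small CPU sketch (shapes chosen for illustration, not from the post):

```python
import torch

a = torch.randn(1, 1, 4).expand(2, 3, 4)       # still a view, no allocation
index = torch.zeros(2, 3, 4, dtype=torch.long)  # index into dim 1

out = a.gather(dim=1, index=index)

print(out.data_ptr() != a.data_ptr())  # True: new storage was allocated
print(out._base is None)               # True: out is not a view of a
```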

My question is: in order to better optimise my code, how can I tell which operations/functions are likely to defer memory allocation? Is there any way I can apply `Tensor.gather`, for example, without the output being explicitly allocated in memory, as with `torch.mul`?

I don’t think that’s true as I’m seeing the expected increase in memory:

```python
a = torch.randn(1, 1, 1, 1, 1, 1, device='cuda')
b = torch.randn(1, 1, 1, 1, 1, 1, device='cuda')

print(torch.cuda.memory_allocated())
# 2048

a = a.expand(-1, 50, 50, 50, -1, -1)
b = b.expand(-1, 50, 50, 50, -1, -1)
print(torch.cuda.memory_allocated())
# 2048

c = torch.mul(a, b)
print(torch.cuda.memory_allocated())
# 502272

d = torch.sum(c, dim=(1, 2, 3))
print(torch.cuda.memory_allocated())
# 502784
```
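The same thing can be confirmed without a GPU: the product over expanded inputs is a fresh, fully materialized tensor. A minimal CPU sketch (sizes taken from the snippet above):

```python
import torch

a = torch.randn(1, 1, 1, 1, 1, 1).expand(-1, 50, 50, 50, -1, -1)
b = torch.randn(1, 1, 1, 1, 1, 1).expand(-1, 50, 50, 50, -1, -1)

c = torch.mul(a, b)

print(c.data_ptr() != a.data_ptr())  # True: new storage
print(c.is_contiguous())             # True: all 125000 elements were written
print(c.numel() * c.element_size())  # 500000 bytes for float32
```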

> I don’t think that’s true as I’m seeing the expected increase in memory

Ah, I did not realise that memory was deallocated when leaving the function scope. The memory checks I was performing previously were therefore masking the true memory requirements, as I was only checking memory allocation after returning from the module method.

```python
class Module(torch.nn.Module):

    def forward(self, a, b):
        a = a.expand(-1, 50, 50, 50, -1, -1)
        b = b.expand(-1, 50, 50, 50, -1, -1)
        print(torch.cuda.memory_allocated())
        c = torch.mul(a, b)
        print(torch.cuda.memory_allocated())
        d = torch.sum(c, dim=(1, 2, 3))
        print(torch.cuda.memory_allocated())

        return d

module = Module()

d = module(a, b)
# 1024
# 501248
# 501760

print(torch.cuda.memory_allocated())
# 1536
```
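The same scope-based deallocation can also be observed without a GPU by holding only a weak reference to the intermediate (a minimal sketch, not from the original thread):

```python
import torch
import weakref

ref = None

def forward(a, b):
    global ref
    c = torch.mul(a, b)   # large intermediate, local to this function
    ref = weakref.ref(c)  # a weak reference does not keep c alive
    return c.sum()

out = forward(torch.randn(100, 100), torch.randn(100, 100))
print(ref() is None)  # True: c was freed when the function returned
```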

So I guess my main question is: is there any way of not writing the intermediate elementwise multiplication to memory? I tried using `torch.utils.checkpoint.checkpoint` but still ran into OOM problems.
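One workaround, if true fusion isn't available, is to process one of the expanded dimensions in chunks and accumulate the reduction, so only a fraction of the product is live at any time. A hedged sketch (the `chunk` size and the expand sizes are assumptions based on the snippets above, and `chunked_mul_sum` is a hypothetical helper, not part of the original module):

```python
import torch

def chunked_mul_sum(a, b, chunk=10):
    # Equivalent to (a.expand(...) * b.expand(...)).sum(dim=(1, 2, 3)),
    # but only materializes chunk/50 of the product at a time.
    a = a.expand(-1, 50, 50, 50, -1, -1)
    b = b.expand(-1, 50, 50, 50, -1, -1)
    out = a.new_zeros(a.shape[0], a.shape[4], a.shape[5])
    for i in range(0, a.shape[1], chunk):
        # slicing an expanded tensor is still a view, so this is cheap
        out += torch.mul(a[:, i:i + chunk], b[:, i:i + chunk]).sum(dim=(1, 2, 3))
    return out
```

Peak memory for the intermediate drops by roughly a factor of `50 / chunk`, at the cost of a Python-level loop; the chunked slices remain differentiable through autograd in the usual way.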