I’m creating a custom nn.Module and I’ve been running into GPU OOM problems because at times the tensors within the module can get quite large. From a memory management perspective I know that tensor methods like Tensor.view, Tensor.expand etc. are cheap, since they return views and don’t copy the underlying data. To better optimise my code I have been using pytorch_memlab to inspect the memory footprint.
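For example, a quick way to confirm that these methods alias the original storage rather than copying it is to compare data pointers (the shapes below are just illustrative):

import torch

x = torch.randn(4, 1, 1, 1, 8, 8, device="cuda")
v = x.view(4, 64)                     # reshape without copying
e = x.expand(-1, 50, 50, 50, -1, -1)  # broadcast without copying (stride 0 on the expanded dims)

# All three tensors share the same underlying storage, so no extra GPU memory is allocated
print(x.data_ptr() == v.data_ptr() == e.data_ptr())  # True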
I have noticed that the following operations do not assign the intermediate tensor c in memory, i.e. the pointwise multiplication result is not stored in memory and is instead immediately passed to the torch.sum function.
def some_module_func(self, a, b):
    # First expand a & b using the cheap expand operation
    a = a.expand(-1, 50, 50, 50, -1, -1)
    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication; if this tensor were saved in memory it would be large
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
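A rough way to double-check the peak usage independently of pytorch_memlab is the CUDA allocator's peak counter; here module, a and b are stand-ins for the real objects:

import torch

torch.cuda.reset_peak_memory_stats()
d = module.some_module_func(a, b)   # placeholder call with the real module and inputs
torch.cuda.synchronize()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e6:.1f} MB")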
However, an operation like torch.gather seemingly will assign a new tensor in memory as output. Here the expanded tensor a will be created in memory after the a.gather operation.
def some_other_module_func(self, a, b, indexer):
    a = a.expand(-1, 50, 50, 50, -1, -1)
    # Here we rearrange dimension 1 to some predefined order
    a = a.gather(dim=1, index=indexer)  # this large tensor is written to memory
    b = b.expand(-1, 50, 50, 50, -1, -1)
    # Then do elementwise multiplication
    c = torch.mul(a, b)
    # Then sum over the expanded dimensions
    d = torch.sum(c, dim=(1, 2, 3))
    return d
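That observation is consistent with a quick storage-pointer check: the output of gather owns fresh memory, whereas the expand output is just a view (shapes are illustrative and indexer here is a stand-in for my real index tensor):

import torch

a = torch.randn(2, 1, 1, 1, 4, 4, device="cuda")
a_exp = a.expand(-1, 50, 50, 50, -1, -1)
indexer = torch.zeros_like(a_exp, dtype=torch.long)   # placeholder index tensor

print(a_exp.data_ptr() == a.data_ptr())   # True: expand is a view, no new allocation
g = a_exp.gather(dim=1, index=indexer)
print(g.data_ptr() == a.data_ptr())       # False: gather materialises a new tensor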
My question is: in order to better optimise my code, how can I tell which operations/functions are likely to defer memory assignment? Is there any way that I can apply Tensor.gather, for example, without the output being explicitly assigned in memory, like with torch.mul?