How to tell PyTorch to not allocate new memory and reuse old memory?

Hi, I’m trying to run a large batch of small matrix multiplications (torch.mm) on the CPU. Here is a minimal working example:

import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn([64, 32])
b = torch.randn([32, 64])
# c = torch.empty([64, 64])
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    for _ in range(10000):
        c = torch.mm(a, b)
    
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_memory_usage", row_limit=100))
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=100))

The output of this snippet looks like this:

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls              Input Shapes  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------------------  
              aten::mm        99.90%     267.431ms       100.00%     267.711ms      26.771us     130.28 Mb     127.64 Mb         10000      [[64, 32], [32, 64]]  
    aten::resolve_conj         0.05%     123.000us         0.05%     123.000us       0.012us       2.64 Mb       2.64 Mb          9962                [[64, 64]]  
    aten::resolve_conj         0.03%      87.000us         0.03%      87.000us       0.009us           0 b           0 b          9975                [[32, 64]]  
    aten::resolve_conj         0.03%      70.000us         0.03%      70.000us       0.007us           0 b           0 b         10000                [[64, 32]]  
              [memory]         0.00%       0.000us         0.00%       0.000us       0.000us    -130.27 Mb    -130.27 Mb          8337                        []  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------------------  
Self CPU time total: 267.711ms

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::mm        99.90%     267.431ms       100.00%     267.711ms      26.771us     130.28 Mb     127.64 Mb         10000  
    aten::resolve_conj         0.10%     280.000us         0.10%     280.000us       0.009us       2.64 Mb       2.64 Mb         29937  
              [memory]         0.00%       0.000us         0.00%       0.000us       0.000us    -130.27 Mb    -130.27 Mb          8337  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 267.711ms

I found that PyTorch never seems to reuse the small allocation for c (64 × 64 × 4 bytes). Instead, it allocates new memory for c on every iteration (or in batches), which accumulates to about 130 MB by the end. I have tried allocating c ahead of time (the commented line) and changing the mm line to c[:, :] = torch.mm(a, b), but neither works.

Is there a way to tell PyTorch to keep reusing the old small buffer instead of re-allocating? Thanks in advance!

There is an out parameter in the torch.mm() call which can help in this case.

https://pytorch.org/docs/stable/generated/torch.mm.html

 torch.mm(a, b, out=c)
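
For reference, a minimal sketch of the original loop rewritten with a preallocated output buffer (same shapes as in the question), so c is reused on every iteration instead of being re-allocated:

import torch

a = torch.randn([64, 32])
b = torch.randn([32, 64])
c = torch.empty([64, 64])  # allocate the output buffer once, up front

for _ in range(10000):
    torch.mm(a, b, out=c)  # result is written into c; no new allocation per call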

This solves my problem perfectly. Thank you very much @InnovArul!