Hi, I’m trying to run a large batch of small matrix multiplications (mm) on CPU. Here is a minimal working example:
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn([64, 32])
b = torch.randn([32, 64])
# c = torch.empty([64, 64])

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    for _ in range(10000):
        c = torch.mm(a, b)

print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_memory_usage", row_limit=100))
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=100))
The output of this snippet looks like this:
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls Input Shapes
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------------------
aten::mm 99.90% 267.431ms 100.00% 267.711ms 26.771us 130.28 Mb 127.64 Mb 10000 [[64, 32], [32, 64]]
aten::resolve_conj 0.05% 123.000us 0.05% 123.000us 0.012us 2.64 Mb 2.64 Mb 9962 [[64, 64]]
aten::resolve_conj 0.03% 87.000us 0.03% 87.000us 0.009us 0 b 0 b 9975 [[32, 64]]
aten::resolve_conj 0.03% 70.000us 0.03% 70.000us 0.007us 0 b 0 b 10000 [[64, 32]]
[memory] 0.00% 0.000us 0.00% 0.000us 0.000us -130.27 Mb -130.27 Mb 8337 []
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------------------
Self CPU time total: 267.711ms
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 99.90% 267.431ms 100.00% 267.711ms 26.771us 130.28 Mb 127.64 Mb 10000
aten::resolve_conj 0.10% 280.000us 0.10% 280.000us 0.009us 2.64 Mb 2.64 Mb 29937
[memory] 0.00% 0.000us 0.00% 0.000us 0.000us -130.27 Mb -130.27 Mb 8337
---------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 267.711ms
It looks like PyTorch never reuses the small allocation for c (64 × 64 × 4 bytes). Instead it allocates fresh memory for c on every iteration, either each time or in batches, which accumulates to about 130 MB by the end. I have tried pre-allocating c (the commented-out line) and changing the mm line to c[:, :] = torch.mm(a, b), but neither helps.
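For reference, the slice-assignment variant I tried looks like this (my understanding is that torch.mm still materializes a fresh temporary on the right-hand side before the copy, which would explain why the allocations don’t go away):

```python
import torch

a = torch.randn([64, 32])
b = torch.randn([32, 64])

# Pre-allocate c once, then copy each result into it via slice assignment.
# The right-hand side torch.mm(a, b) presumably still allocates a temporary
# 64x64 tensor on every iteration; only the final copy targets c's buffer.
c = torch.empty([64, 64])
for _ in range(10000):
    c[:, :] = torch.mm(a, b)
```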
Is there a way to tell PyTorch to keep reusing the same small buffer instead of re-allocating? Thanks in advance!
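One candidate I found in the docs but have not profiled yet is the out= keyword argument of torch.mm, which is documented to write the result into an existing tensor. Is this the intended mechanism for reusing the buffer, and should it eliminate the per-iteration allocations? A minimal sketch of what I mean:

```python
import torch

a = torch.randn([64, 32])
b = torch.randn([32, 64])
c = torch.empty([64, 64])

ptr = c.data_ptr()  # remember where c's storage lives
for _ in range(1000):
    torch.mm(a, b, out=c)  # ask mm to write into the pre-allocated buffer

# The storage pointer is unchanged, so c itself was not re-allocated.
assert c.data_ptr() == ptr
```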