Torch.matmul high memory usage

Hi!
The following snippet consumes a lot of memory, and I am having trouble understanding why.

import torch
from torch.autograd import profiler

with profiler.profile(profile_memory=True) as prof:
    b = torch.rand(1000, 77, 150, 1)
    w = torch.rand(77, 150, 150)
    out = torch.matmul(w, b)

print(prof.table(sort_by="cpu_memory_usage"))

This produces the following output:

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      aten::contiguous         0.00%       5.302us        54.39%     790.184ms     790.184ms       6.45 Gb           0 b             1  
      aten::empty_like         0.00%       3.581us         0.00%      15.254us      15.254us       6.45 Gb           0 b             1  
           aten::empty         0.00%      11.673us         0.00%      11.673us      11.673us       6.45 Gb       6.45 Gb             1  
            aten::rand         0.00%      11.606us         5.18%      75.216ms      75.216ms      44.06 Mb           0 b             1  
           aten::empty         0.00%      20.923us         0.00%      20.923us      20.923us      44.06 Mb      44.06 Mb             1  
          aten::matmul        19.57%     284.278ms        94.10%        1.367s        1.367s      44.06 Mb      -6.45 Gb             1  
             aten::bmm        20.13%     292.477ms        20.13%     292.502ms     292.502ms      44.06 Mb           0 b             1  
         aten::resize_         0.00%      17.066us         0.00%      17.066us      17.066us      44.06 Mb      44.06 Mb             1  
            aten::rand         0.00%       7.573us         0.73%      10.539ms      10.539ms       6.61 Mb           0 b             1  
           aten::empty         0.00%      16.171us         0.00%      16.171us      16.171us       6.61 Mb       6.61 Mb             1  
        aten::uniform_         5.18%      75.184ms         5.18%      75.184ms      75.184ms           0 b           0 b             1  
        aten::uniform_         0.72%      10.515ms         0.72%      10.515ms      10.515ms           0 b           0 b             1  
          aten::expand         0.00%      10.284us         0.00%      13.832us      13.832us           0 b           0 b             1  
      aten::as_strided         0.00%       3.548us         0.00%       3.548us       3.548us           0 b           0 b             1  
           aten::copy_        54.39%     790.163ms        54.39%     790.163ms     790.163ms           0 b           0 b             1  
            aten::view         0.00%      13.715us         0.00%      13.715us      13.715us           0 b           0 b             1  
          aten::expand         0.00%       5.178us         0.00%       7.071us       7.071us           0 b           0 b             1  
      aten::as_strided         0.00%       1.893us         0.00%       1.893us       1.893us           0 b           0 b             1  
      aten::contiguous         0.00%       1.301us         0.00%       1.301us       1.301us           0 b           0 b             1  
            aten::view         0.00%       2.690us         0.00%       2.690us       2.690us           0 b           0 b             1  
           aten::empty         0.00%       3.193us         0.00%       3.193us       3.193us           0 b           0 b             1  
          aten::stride         0.00%       0.395us         0.00%       0.395us       0.395us           0 b           0 b             1  
          aten::stride         0.00%       0.180us         0.00%       0.180us       0.180us           0 b           0 b             1  
          aten::stride         0.00%       0.175us         0.00%       0.175us       0.175us           0 b           0 b             1  
          aten::stride         0.00%       0.181us         0.00%       0.181us       0.181us           0 b           0 b             1  
          aten::select         0.00%       2.685us         0.00%       3.194us       3.194us           0 b           0 b             1  
      aten::as_strided         0.00%       0.509us         0.00%       0.509us       0.509us           0 b           0 b             1  
          aten::stride         0.00%       0.179us         0.00%       0.179us       0.179us           0 b           0 b             1  
    aten::_unsafe_view         0.00%       7.630us         0.00%      14.069us      14.069us           0 b           0 b             1  
            aten::view         0.00%       6.439us         0.00%       6.439us       6.439us           0 b           0 b             1  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.453s

As you can see, there is an intermediate operation that allocates ~6.5 GB (!!!) of memory.
Is this normal? Why does it happen?

Thanks in advance!

Batched matmul pre-expands all “batch” dimensions to matching sizes, so here the w tensor is replicated 1000 times before the underlying bmm runs.
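The 6.45 Gb figure in the trace matches that replication exactly: broadcasting materializes a copy of w with shape (1000, 77, 150, 150) in float32, which works out to:

```python
# Memory cost of materializing w expanded to b's batch shape.
# w: (77, 150, 150) broadcast against b: (1000, 77, 150, 1)
# -> expanded w: (1000, 77, 150, 150), float32 (4 bytes/element)
elements = 1000 * 77 * 150 * 150
bytes_total = elements * 4          # float32
gib = bytes_total / 2**30
print(f"{gib:.2f} GiB")             # prints "6.45 GiB", matching the profiler
```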

There are two scenarios:

  1. the operation is expressible with 3d tensors and torch.bmm (the backend of matmul)
  2. you’re out of luck, and processing in chunks may be required

The stub (size-1) dimension in b suggests case 1, but these dimensions are confusing… something like 77x1000x150 @ 77x150x150 would perhaps do it.
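For what it's worth, here is one concrete rearrangement along those lines (a sketch with small hypothetical stand-in sizes, not tested against the original shapes): fold the batch-of-1000 dimension into bmm's column dimension, so w is batched only over its own 77 dimension and never replicated. This is the transposed form of the 77x1000x150 @ 77x150x150 suggestion above.

```python
import torch

# Sketch: avoid broadcasting w to (N, G, D, D) by batching over the
# G dimension with torch.bmm and carrying N in the column dimension.
# Small stand-ins for the original (1000, 77, 150, 1) / (77, 150, 150).
N, G, D = 8, 5, 6
b = torch.rand(N, G, D, 1)
w = torch.rand(G, D, D)

# Naive version: w is expanded/copied N times under the hood.
out_naive = torch.matmul(w, b)                    # (N, G, D, 1)

# bmm version: (G, D, D) @ (G, D, N) -> (G, D, N), no expansion of w.
b2 = b.squeeze(-1).permute(1, 2, 0)               # (G, D, N)
out_bmm = torch.bmm(w, b2)                        # (G, D, N)
out_bmm = out_bmm.permute(2, 0, 1).unsqueeze(-1)  # back to (N, G, D, 1)

assert torch.allclose(out_naive, out_bmm, atol=1e-5)
```

The permutes are cheap views; only bmm's output (the same 44 Mb the profiler already shows) is newly allocated, instead of the 6.45 Gb expanded copy of w.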