I have recently been facing a performance issue with large batches of small matrix multiplications in PyTorch. The running time of the script below seems unaffected by the channel number C. Are there any hints for optimizing the running time?
I have tried torch.matmul, torch.bmm, and torch.einsum; all of them give similar timings. Would a custom CUDA kernel help?
import torch
C = 16
# x = torch.randn((4, 32, 31, 64, 64)).cuda()
x = torch.randn((64*64*4, 31, C)).cuda()
# x = torch.randn((4, 64, 64, 31, 32)).cuda()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
steps = int(1e3)
start.record()
for j in range(steps):
    # w1 = torch.matmul(x, x.transpose(-2, -1))
    w2 = torch.bmm(x, x.transpose(-2, -1))
    # w = torch.einsum('bhwqc,bhwkc->bhwqk', x, x)
    # print(torch.allclose(w1, w2))
    # torch.einsum("bhqd,bhkd->bhqk", test_mat, test_mat)
end.record()
torch.cuda.synchronize()
print('Time: {} ms'.format(start.elapsed_time(end)/steps))
print(torch.cuda.memory_allocated()/1024)  # allocated memory in KiB
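For reference, a rough operation count for these shapes (assuming float32): the arithmetic grows with C, but the (31, 31) output per batch, and hence the write traffic, does not, which may explain why the measured time barely moves:

```python
# Back-of-the-envelope: per-step arithmetic vs. output size for the shapes above.
B, N = 64 * 64 * 4, 31  # batch count and row count from the script

for C in (2, 8, 16):
    flops = 2 * B * N * N * C   # multiply-adds for (N, C) @ (C, N), over the full batch
    out_bytes = B * N * N * 4   # float32 (N, N) output per batch, independent of C
    print('C={:2d}: {:.2e} FLOPs, {:.2e} output bytes'.format(C, flops, out_bytes))
```

Even at C=16 this is only about 5e8 FLOPs per step against roughly 6e7 bytes written, so the kernel is plausibly bound by memory traffic and launch overhead rather than by arithmetic.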
Results (time per step; memory_allocated in KiB):
C=16: 1.3880045166015624 ms, 94272.0 KiB
C=8:  1.3616466064453125 ms, 77888.0 KiB
C=2:  1.278853759765625 ms,  65472.0 KiB
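As a sanity check (run on CPU, with made-up small sizes), the flattened torch.bmm in the script and the commented-out 5-D torch.einsum agree up to floating-point tolerance, so the timings should be comparable across the variants:

```python
import torch

# Hypothetical small stand-in sizes for a quick CPU check.
B, H, W, N, C = 2, 4, 4, 31, 16
x5 = torch.randn(B, H, W, N, C)

# 5-D einsum form, as in the commented-out line.
w_einsum = torch.einsum('bhwqc,bhwkc->bhwqk', x5, x5)

# Flattened bmm form, as actually timed in the script.
x3 = x5.reshape(B * H * W, N, C)
w_bmm = torch.bmm(x3, x3.transpose(-2, -1)).reshape(B, H, W, N, N)

print(torch.allclose(w_einsum, w_bmm, atol=1e-5))
```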