Facing a performance issue with a large batch of small matrix multiplications

I am currently facing a performance issue with a large batch of small matrix multiplications in PyTorch. The running time of the script below seems unaffected by the channel count C. Is there any hint for optimizing the running time?

I have tried torch.matmul, torch.bmm, and torch.einsum; all of them give similar results. Would a custom CUDA kernel help?
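For reference, here is a minimal check (with an arbitrarily chosen small shape) that the three formulations agree, since my timings below only use the bmm variant:

import torch

x = torch.randn(8, 31, 16, device='cuda')
w1 = torch.matmul(x, x.transpose(-2, -1))
w2 = torch.bmm(x, x.transpose(-2, -1))
w3 = torch.einsum('bqc,bkc->bqk', x, x)
print(torch.allclose(w1, w2), torch.allclose(w1, w3))  # both True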

import torch

C = 16
x = torch.randn((64*64*4, 31, C)).cuda()
# Alternative layouts tried:
# x = torch.randn((4, 32, 31, 64, 64)).cuda()
# x = torch.randn((4, 64, 64, 31, 32)).cuda()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

steps = int(1e3)
start.record()
for j in range(steps):
    w2 = torch.bmm(x, x.transpose(-2, -1))
    # Equivalent formulations, all with similar timings:
    # w1 = torch.matmul(x, x.transpose(-2, -1))
    # print(torch.allclose(w1, w2))
    # w = torch.einsum('bhwqc,bhwkc->bhwqk', x, x)  # for the 5-D layouts above
end.record()

torch.cuda.synchronize()

print('Time: {} ms'.format(start.elapsed_time(end) / steps))
print(torch.cuda.memory_allocated() / 1024)  # allocated memory in KiB
C = 16: Time: 1.3880045166015624 ms, memory allocated: 94272.0 KiB
C = 8:  Time: 1.3616466064453125 ms, memory allocated: 77888.0 KiB
C = 2:  Time: 1.278853759765625 ms, memory allocated: 65472.0 KiB

There is no meaningful increase in running time as the channel count grows from 2 to 16.
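A rough FLOP count (a sketch using the shapes and timings above) suggests why: even at C = 16 the whole batch amounts to about half a GFLOP, so per-call overhead rather than arithmetic dominates the measured time.

B, M = 64 * 64 * 4, 31                      # batch size and row count from the script above
for C, ms in [(2, 1.278853759765625), (16, 1.3880045166015624)]:
    flops = 2 * B * M * M * C               # multiply-adds for one batched x @ x^T
    print(f'C={C}: {flops / 1e9:.2f} GFLOP, '
          f'{flops / (ms * 1e-3) / 1e12:.3f} TFLOP/s achieved')

At C = 16 this works out to roughly 0.36 TFLOP/s, far below the peak throughput of any recent GPU, which is consistent with the kernels being launch- or memory-bound at these sizes.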

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

steps = 100
with torch.no_grad():
    for C in [10, 100, 1000]:
        x = torch.ones((20000, 31, C), device='cuda:0')
        y = x.transpose(-2, -1)  # build the view once, outside the timed loop
        start.record()
        for j in range(steps):
            w2 = torch.bmm(x, y)
        end.record()

        torch.cuda.synchronize()
        print(f'Channels: {C}')
        print('\t Time: {} ms'.format(start.elapsed_time(end) / steps))
Channels: 10
	 Time: 1.2575846099853516 ms
Channels: 100
	 Time: 1.9832627868652344 ms
Channels: 1000
	 Time: 9.910529174804687 ms
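As a side note, torch.utils.benchmark takes care of warm-up and CUDA synchronization for you; a minimal sketch of timing the same bmm with it:

import torch
import torch.utils.benchmark as benchmark

x = torch.ones((20000, 31, 100), device='cuda:0')
y = x.transpose(-2, -1)

timer = benchmark.Timer(
    stmt='torch.bmm(x, y)',               # the timed statement
    globals={'torch': torch, 'x': x, 'y': y},
)
print(timer.timeit(100))                  # prints a Measurement with per-call times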

Also, you should check whether your GPU utilization actually reaches 100%.
Lastly, your script creates the transpose inside the timed loop, so it measures that as well.
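Here is a minimal sketch of the same measurement with the transposed operand built once outside the timed region (and made contiguous, so no strided layout is handled inside the loop), plus a warm-up pass; names and shapes follow the original script:

import torch

C = 16
x = torch.randn((64 * 64 * 4, 31, C), device='cuda')
y = x.transpose(-2, -1).contiguous()   # materialize the transposed operand once

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# warm-up so the first cuBLAS call does not pollute the measurement
for _ in range(10):
    torch.bmm(x, y)
torch.cuda.synchronize()

steps = 1000
start.record()
for _ in range(steps):
    w = torch.bmm(x, y)
end.record()
torch.cuda.synchronize()

print('Time: {} ms'.format(start.elapsed_time(end) / steps))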