Speeding up torch.matmul

Hi, I would like to compute the matrix multiplication of two matrices.
torch.matmul gives the correct result, but it is slow.
Could you please give me some advice on speeding up the matrix multiplication?
I use the following code to measure the time.

import time
import torch

cols = torch.randn(16, 57600, 1, 108).cuda()
local_weight = torch.randn(16, 57600, 108, 3).cuda()

with torch.no_grad():
    for i in range(10):
        torch.cuda.synchronize()  # make sure prior GPU work is done before timing
        t0 = time.time()
        c = cols.matmul(local_weight)  # same as torch.matmul(cols, local_weight)
        torch.cuda.synchronize()  # wait for the kernel to finish
        print(time.time() - t0)
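
An equivalent measurement with CUDA events, reusing cols and local_weight from
above, is sketched here; torch.cuda.Event times device-side work directly and
reports elapsed time in milliseconds:

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    for i in range(10):
        start.record()
        c = cols.matmul(local_weight)
        end.record()
        torch.cuda.synchronize()        # wait for both events to complete
        print(start.elapsed_time(end))  # elapsed time in milliseconds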

One call to torch.matmul takes about 5 ms. Is it possible to make it faster?

Hi Huang!

This sounds quite fast to me.

Please note that the (two-dimensional) batch of matrix multiplications that
you are performing is quite large.

In your example, one call to matmul() performs just a little under 600
million floating-point operations. Doing so in 5 milliseconds corresponds
to 120 gigaflops. That’s very good.
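
As a quick sketch, here is the arithmetic behind those numbers (using the
usual convention of 2 floating-point operations per multiply-accumulate):

batch = 16 * 57600                      # independent (1 x 108) @ (108 x 3) matmuls
flops_per_matmul = 2 * 1 * 108 * 3      # 2 * m * k * n = 648
total_flops = batch * flops_per_matmul  # ~5.97e8, a little under 600 million
print(total_flops / 5e-3 / 1e9)         # ~119 gigaflops at 5 ms per call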

In principle, you could further parallelize your batch of matrix multiplications
across multiple GPUs, but whether that would give you improved throughput
for your actual use case would depend on a lot of details.
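
A minimal sketch of that idea, assuming two visible CUDA devices and reusing
cols and local_weight from the code above (whether the device-to-device copies
eat the savings depends on your setup):

# Split the batch dimension across two GPUs; kernel launches are asynchronous,
# so the two halves can execute concurrently.
cols0, cols1 = cols.chunk(2, dim=0)
w0, w1 = local_weight.chunk(2, dim=0)
cols1, w1 = cols1.to("cuda:1"), w1.to("cuda:1")

c0 = cols0.matmul(w0)  # runs on cuda:0
c1 = cols1.matmul(w1)  # runs on cuda:1

# Gather the result back on the first device.
c = torch.cat([c0, c1.to("cuda:0")], dim=0)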

(As an aside, I can basically reproduce your timings, except that I get about
50 ms per matmul() on my not-very-powerful GeForce GTX 1050 Ti, for a
speed of 12 gigaflops.)

Best.

K. Frank