Failing to make use of PyTorch's asynchronous GPU operations

I am testing the code below. I want to speed it up by splitting the data into “batches”, hoping to take advantage of the asynchronous execution of GPU operations. However, when I set the “batch” value to 16 or even larger, the run time is almost the same as with a value of 1. I have no idea why the computation time doesn't go down. Am I misunderstanding how asynchronous operations work?

```python
import time

import numpy as np
import torch


def improved_efficient_matmul(a, c, index, batch=2):
    """
    :param a: N * I * J
    :param c: M * J * K, one J * K matrix selected per sample via `index`
    :param index: N integer indices into c
    :return:  N * I * K
    """
    # Assumes N is divisible by `batch`; any remainder rows are dropped.
    per_batch_len = a.shape[0] // batch
    out = []

    for chunk in range(batch):
        offset = per_batch_len * chunk
        # One torch.matmul call per sample: a 1 x I x J slice times a J x K matrix
        out.append(torch.cat([torch.matmul(a[offset + i:offset + i + 1, :, :], c[index[offset + i], :, :])
                              for i in range(per_batch_len)], dim=0))
    return torch.cat(out, dim=0)


# rad: N random indices into the 16384 weight matrices (kept on the CPU here)
rad = np.random.randint(0, high=16384, size=1048576)
rad = torch.from_numpy(rad).long()

a = torch.rand([1048576, 1, 64]).cuda()
b = torch.rand([16384, 64, 64]).cuda().requires_grad_()

print(torch.cuda.memory_allocated() // 1024 // 1024)  # MiB currently allocated
print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

torch.cuda.synchronize()  # make sure no earlier GPU work is timed
start = time.time()
out1 = improved_efficient_matmul(a, b, rad, 1)
torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
end = time.time()

print(end - start)
print(torch.cuda.memory_allocated() // 1024 // 1024)
print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

print(out1.shape)
```

This post and the GTC presentation linked in my other post in the same thread might be interesting.
TL;DR: use streams and make sure compute resources are available. Matmul kernels tend to have high occupancy (they already keep most of the GPU busy), so you might not be able to overlap them.
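Following the TL;DR above, here is a minimal sketch of splitting one batched matmul across CUDA streams. `bmm_on_streams` is a hypothetical helper, not a PyTorch API; it only uses `torch.cuda.Stream`, the `torch.cuda.stream` context manager, `Stream.wait_stream`, and `Tensor.record_stream`. Whether the chunks actually overlap depends on free compute resources: if each `bmm` already saturates the GPU, expect little or no speedup, and a profiler such as Nsight Systems is the way to check.

```python
import torch


def bmm_on_streams(a, w, num_streams=2):
    """Hypothetical helper: compute torch.bmm(a, w) in chunks on separate CUDA streams.

    a: N x I x J, w: N x J x K. On the CPU the chunks simply run one after another.
    """
    if a.shape[0] != w.shape[0]:
        raise ValueError("a and w need the same batch size")
    chunks = list(zip(a.chunk(num_streams), w.chunk(num_streams)))
    if not a.is_cuda:
        return torch.cat([torch.bmm(x, y) for x, y in chunks])

    main = torch.cuda.current_stream()
    streams = [torch.cuda.Stream() for _ in chunks]
    outs = []
    # Launch every chunk first so the side streams are free to run concurrently.
    for s, (x, y) in zip(streams, chunks):
        s.wait_stream(main)  # inputs were produced on the main stream
        with torch.cuda.stream(s):
            outs.append(torch.bmm(x, y))
    # Only then make the main stream wait for the results.
    for s, out in zip(streams, outs):
        main.wait_stream(s)
        out.record_stream(main)  # allocated on s, consumed on main by torch.cat
    return torch.cat(outs)
```

All chunks are launched before the main stream waits on any side stream; calling `main.wait_stream(s)` inside the launch loop would make each new side stream wait for the previous chunk and serialize the work.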

Why do you think you can improve the PyTorch torch.matmul operation even further by using batches? In theory, why would this be worth trying?

Have you tried torch.bmm?
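One way to apply this suggestion, assuming `a`, `c`, and `index` are shaped as in the question's docstring: gather `c[index]` once and call `torch.bmm`, which replaces the per-sample Python loop with a single batched call. `gather_bmm` is a hypothetical name. The trade-off is memory, because `c[index]` materializes a full N x J x K copy.

```python
import torch


def gather_bmm(a, c, index):
    """out[n] = a[n] @ c[index[n]] in a single batched call.

    a: N x I x J, c: M x J x K, index: N integer indices into c.
    c[index] builds a full N x J x K copy, so memory grows with N.
    """
    return torch.bmm(a, c[index])


# Small check against the per-sample loop used in the question
a = torch.rand(10, 1, 4)
c = torch.rand(3, 4, 5)
index = torch.randint(0, 3, (10,))
loop = torch.cat([torch.matmul(a[i:i + 1], c[index[i]]) for i in range(10)], dim=0)
assert torch.allclose(gather_bmm(a, c, index), loop)
```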

Have you tried the standard Θ(n^3) matrix multiplication? In practice it is much faster than the asymptotically better O(n^2.3728596) algorithm given by Alman and Vassilevska Williams.

Thanks for your reply. The post helped a lot.

Many thanks for your questions. I actually want to multiply two large tensors, which, however, can cause OOD on my device if I use torch.matmul(A, B) directly. So I want to reduce the memory cost by doing the matmul for each sample separately while keeping the speed.
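A rough estimate of why the direct approach runs out of memory, assuming "directly" means first gathering `b[rad]` into one N x 64 x 64 tensor, with the float32 sizes from the question's code. `mib` is a hypothetical helper, and the estimate ignores autograd buffers and allocator overhead.

```python
# Memory needed by the tensors in the question (float32 = 4 bytes per element).
BYTES_PER_FLOAT32 = 4
N, I, J, K = 1048576, 1, 64, 64  # a: N x I x J, each selected weight matrix: J x K
M = 16384                        # number of distinct weight matrices in b


def mib(num_elements):
    """Size in MiB of a float32 tensor with num_elements elements."""
    return num_elements * BYTES_PER_FLOAT32 / 1024**2


a_mib = mib(N * I * J)         # input a
b_mib = mib(M * J * K)         # weight bank b
gathered_mib = mib(N * J * K)  # b[rad] materialized for one big matmul/bmm
out_mib = mib(N * I * K)       # result

print(a_mib, b_mib, gathered_mib, out_mib)  # 256.0 256.0 16384.0 256.0
```

The gathered copy alone needs 16 GiB, 64 times the size of `b` itself, while every other tensor is 256 MiB.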

I have just found where my misunderstanding is in the code above. Just breaking the two tensors into batches achieves what I want, that is: torch.matmul(A[batch, …], B[index[batch], …]). (The code above still multiplies one sample at a time, i.e. an effective batch size of 1, which is why it is so slow.) Previously, I was worried that B[index[batch]] would create a new copy of the tensor and increase the GPU memory cost, but it seems it doesn't.
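A minimal sketch of the chunked approach described above, assuming `a`, `c`, and `index` are shaped as in the question's docstring; `chunked_gather_matmul` and `chunk_size` are hypothetical names and 65536 is an arbitrary default. Note that advanced indexing such as `c[index[start:stop]]` does return a copy rather than a view; extra memory stays bounded because that copy only covers one chunk at a time. This bound is for the forward pass: when `c` requires grad, autograd may keep per-chunk tensors alive for backward, so check `torch.cuda.max_memory_allocated()` in your own setup.

```python
import torch


def chunked_gather_matmul(a, c, index, chunk_size=65536):
    """out[n] = a[n] @ c[index[n]], computed chunk by chunk.

    Only one chunk of c[index] (chunk_size x J x K) is materialized at a time,
    so the temporary copy is bounded by chunk_size instead of N.
    """
    index = index.to(c.device)
    outs = []
    for start in range(0, a.shape[0], chunk_size):
        stop = start + chunk_size  # slicing past the end is fine for the last chunk
        # c[index[start:stop]] is a copy (advanced indexing), but only of this chunk
        outs.append(torch.matmul(a[start:stop], c[index[start:stop]]))
    return torch.cat(outs, dim=0)
```

For the sizes in the question, the default chunk size keeps the temporary copy at 65536 x 64 x 64 floats (1 GiB) instead of 16 GiB, while each call still launches large batched kernels.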

bmm looks ideal for your case, so why not just use that? matmul is just the general version, which eventually ends up in bmm, if I recall correctly?
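A small sketch of how the two functions relate from the caller's side. It only shows that results agree for 3-D inputs with equal batch sizes and that `matmul` broadcasts while `bmm` does not; it does not confirm how `matmul` dispatches internally.

```python
import torch

x = torch.rand(6, 2, 3)
y = torch.rand(6, 3, 4)

# For two 3-D tensors with the same batch size, matmul and bmm agree.
assert torch.allclose(torch.matmul(x, y), torch.bmm(x, y))

# matmul also broadcasts: one shared 3 x 4 matrix is applied to every batch entry.
shared = torch.rand(3, 4)
print(torch.matmul(x, shared).shape)  # torch.Size([6, 2, 4])

# bmm does not broadcast; it requires two 3-D inputs with equal batch size.
try:
    torch.bmm(x, shared)
except RuntimeError:
    print("bmm rejected the 2-D input")
```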

Also, what is OOD? Out of democracy? Did you mean out of memory or out of kernels?

Sorry for the typo; I meant “out of memory”. Actually, torch.bmm doesn't help much in my task. After running more experiments today, I have summarized the weird phenomena I came across and posted them in more detail here: https://discuss.pytorch.org/t/many-weird-phenomena-about-torch-matmul-operation/158208