Many weird phenomena with the "torch.matmul()" operation

A note in advance: this will be a long post, but these phenomena have genuinely confused me over the past few days.


In my recent work, I need to perform a matrix multiplication between two large tensors. The first approach that came to my mind was to use the "torch.matmul()" function in PyTorch. However, after running many experiments, I think I have come across several weird phenomena, and it seems there may be considerable room for improvement in how "torch.matmul()" handles this case.

Task Description:

  1. The two large tensors to be multiplied are A and C; each is a 3D tensor.
  2. A has size [N, I, J] and C has size [M, J, K], where N is not equal to M.
  3. An extra index tensor (call it Q) has size [N] and is used to query C; the queried result C[Q] has size [N, J, K].
  4. We can then perform a batched matmul between A and C[Q] (see the toy sketch below).
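
For clarity, here is a minimal toy sketch of the task (the concrete sizes below are only illustrative; the real sizes appear in the main code further down):

import torch

N, M, I, J, K = 8, 4, 2, 5, 3             # toy sizes, just for illustration
A = torch.rand(N, I, J)
C = torch.rand(M, J, K)
Q = torch.randint(0, M, (N,))              # index tensor used to query C

out = torch.matmul(A, C[Q])                # [N, I, K]

# per-element view of the same computation: out[n] = A[n] @ C[Q[n]]
ref = torch.stack([A[n] @ C[Q[n]] for n in range(N)])
assert torch.allclose(out, ref)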

Strategies:

I designed three different strategies for the above task and compared them with the direct use of "torch.matmul()".

  • Strategy1: Take A and C[Q] as input, compute the product one batch element at a time, and concatenate the results at the end.
out1 = strategy1(a, c[q])

def strategy1(a, c):
    """
    :param a: N * I * J
    :param c: N * J * K
    :return:  N * I * K
    """
    out = torch.cat(
        [torch.matmul(a[i:i + 1, :, :], c[i, :, :]) for i in
         range(a.shape[0])], dim=0)
    return out
  • Strategy2: Take A and C as input, compute the product one element at a time while querying C with index Q, and concatenate the results at the end.
out2 = strategy2(a, c, q)

def strategy2(a, c, index):
    """
    :param a: N * I * J
    :param c: N * J * K
    :return:  N * I * K
    """

    out = torch.cat(
        [torch.matmul(a[i:i + 1, :, :], c[index[i], :, :]) for i in
         range(a.shape[0])], dim=0)
    return out
  • Strategy3: Take A and C as input, compute the product in batches (chunks) while querying C with index Q, and concatenate the results at the end.
out3 = strategy3(a, c, q, 256)

def strategy3(a, c, index, batch=256):
    """
    :param a: N * I * J
    :param c: N * J * K
    :return:  N * I * K
    """

    out = torch.cat(
        [torch.matmul(a[i * batch:i * batch + batch, :, :], c[index[i * batch:i * batch + batch], :, :]) for i in
         range(a.shape[0] // batch)], dim=0)
    return out
  • Main Code:
import time

import numpy as np
import torch


if __name__ == "__main__":
    q = np.random.randint(0, high=16384 // 8, size=1048576 // 8)
    q = torch.from_numpy(q).long()

    a = torch.rand([1048576 // 8, 1, 64]).cuda()
    c = torch.rand([16384 // 8, 64, 64]).cuda().requires_grad_()

    print("curMem", torch.cuda.memory_allocated() // 1024 // 1024)
    print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

    print("\n=====1=======")
    torch.cuda.synchronize()
    start = time.time()
    out1 = strategy1(a, c[q])
    torch.cuda.synchronize()
    end = time.time()

    print("time", end - start)
    print("shape", out1.shape)
    print("curMem", torch.cuda.memory_allocated() // 1024 // 1024)
    print("maxMem", torch.cuda.max_memory_allocated() // 1024 // 1024)

    print("\n=====2=======")
    torch.cuda.synchronize()
    start = time.time()
    out2 = strategy2(a, c, q)
    torch.cuda.synchronize()
    end = time.time()

    print("time", end - start)
    print("shape", out2.shape)
    print("curMem", torch.cuda.memory_allocated() // 1024 // 1024)
    print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

    print("\n=====3=======")
    torch.cuda.synchronize()
    start = time.time()
    out3 = strategy3(a, c, q, 256)
    torch.cuda.synchronize()
    end = time.time()

    print("time", end - start)
    print("shape", out3.shape)
    print("curMem", torch.cuda.memory_allocated() // 1024 // 1024)
    print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

    print("\n=====4=======")
    torch.cuda.synchronize()
    start = time.time()
    out4 = torch.matmul(a, c[q])
    torch.cuda.synchronize()
    end = time.time()

    print("time", end - start)
    print("shape", out4.shape)
    print("curMem", torch.cuda.memory_allocated() // 1024 // 1024)
    print("max", torch.cuda.max_memory_allocated() // 1024 // 1024)

Weird Phenomena & Questions:

Since the results differ a lot between settings, I divide them into two parts.
Part 1: A does not have requires_grad set, regardless of whether C does (Directly-4 means using torch.matmul() directly):

  • Strategy1: time 6.49213 || curMem 96 || maxMem 2208
  • Strategy2: time 7.27945 || curMem 96 || maxMem 160
  • Strategy3: time 1.37562 || curMem 96 || maxMem 128
  • Directly-4: time 1.37301 || curMem 96 || maxMem 2144
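
As a back-of-the-envelope check of these numbers (my own calculation, assuming float32, i.e. 4 bytes per element):

MiB = 1024 * 1024
N, M, I, J, K = 1048576 // 8, 16384 // 8, 1, 64, 64

print(N * I * J * 4 // MiB)   # a:    32 MiB
print(M * J * K * 4 // MiB)   # c:    32 MiB
print(N * I * K * 4 // MiB)   # out:  32 MiB  -> 96 MiB in total, matching curMem
print(N * J * K * 4 // MiB)   # C[Q]: 2048 MiB, matching the maxMem - curMem gap of Directly-4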

From the above, it can be seen that:

  1. torch.matmul() consumes much more memory during its computation than Strategy2 and Strategy3, even though their curMem values are the same. This is indeed weird: why does torch.matmul() cost that much along the way? One hypothesis is that when C is queried with index Q, PyTorch creates a copy of the indexed tensor, which causes the extra memory, and releases the copy once the calculation has finished.
  2. Strategy1 performs almost the same computation as Strategy2, yet their maxMem values are quite different. I looked into this and found it is caused by the input C[Q] in Strategy1. Specifically, when one uses X[index_a][index_b], PyTorch may create a temporary copy, which costs memory, and release it afterwards.

Questions 1 and 2 above share two further doubts: why is the tensor copied at all, when there seems to be no need for it? And why can the copied tensor be released automatically, given that C[Q] requires grad?
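
To test the copy hypothesis more directly, the indexing step can be measured in isolation. In PyTorch, indexing with a tensor of indices (advanced indexing) returns a new tensor rather than a view, so c[q] alone should allocate a full [N, J, K] buffer. A small probe (assuming the same a, c, q as in the main code):

before = torch.cuda.memory_allocated()
tmp = c[q]                    # advanced indexing: gathers rows of c into a new tensor
torch.cuda.synchronize()
after = torch.cuda.memory_allocated()
print("extra MiB held by c[q]:", (after - before) // 1024 // 1024)
del tmp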


Part 2: A has requires_grad set, regardless of whether C does (Directly-4 again means using torch.matmul() directly):

  • Strategy1: time 7.05876 || curMem 2144 || maxMem 2208
  • Strategy2: time 7.99773 || curMem 96 || maxMem 160
  • Strategy3: time 1.44025 || curMem 2144 || maxMem 2176
  • Directly-4: time 1.38115 || curMem 2144 || maxMem 2144

There are even more weird phenomena when A has requires_grad set:

  1. The curMem of both Strategy1 and Directly-4 increases and gets close to maxMem. This supports the hypothesis from the previous part that PyTorch performs a tensor copy; however, here the copied tensor is not released but kept. I guess this may be because both input tensors now require grad. But that again raises the question of why the copy is kept in this case and released in the previous one.
  2. Strategy2 appears to be the better option here, as it keeps both curMem and maxMem low, although it is quite slow.
  3. Another weird phenomenon is that Strategy3's curMem also increases, which is hard to explain with the hypothesis above. I investigated this as well and found the cause: c[index[i * batch:i * batch + batch], :, :]. When this is changed to c[index[i * batch], :, :], curMem returns to 96, as in Strategy2. This actually means that, when A has requires_grad set, one should make sure the second input to torch.matmul() is a 2D tensor, not a batched (3D) one, in order to avoid creating a copied tensor! That in turn makes it difficult to find a solution that keeps both high speed and low memory cost (both curMem and maxMem) when A and C both require grad (see the sketch after this list).
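
Building on point 3, one possible direction (only a sketch, not a tuned or verified solution; the name strategy_grouped is mine) is to group the rows of A by their index value, so that the second operand of torch.matmul() is always a 2D slice c[m] and no [N, J, K] copy of C is ever materialized:

def strategy_grouped(a, c, index):
    """
    :param a:     N * I * J
    :param c:     M * J * K
    :param index: N
    :return:      N * I * K
    """
    index = index.to(a.device)
    out = a.new_empty(a.shape[0], a.shape[1], c.shape[2])
    for m in torch.unique(index):
        mask = index == m
        # the second matmul operand is the 2D slice c[m], so no batched copy of C is built
        out[mask] = torch.matmul(a[mask], c[m])
    return out

Whether this is actually fast depends on how many distinct values Q contains (here M = 2048, so the Python loop runs up to 2048 times), so it would need to be timed against Strategy3.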

To sum up, unless I am mistaken, there seem to be quite a few weird phenomena related to torch.matmul() for this specific task, and the experiments suggest that Strategy2 and Strategy3 can be better choices than calling torch.matmul() directly.

I hope this post can spark more discussion here, and that more people will try the code above. I am looking forward to a solution that achieves both high speed and low memory cost.

Since I am not familiar with PyTorch's underlying source code or with how CUDA operates, I would be very grateful if anyone could point out anything I have misunderstood.