I’m trying to take advantage of PyTorch’s autograd feature and perform the matrix-matrix multiplication $A \times B$, where matrix A is represented as a list of Tensors, each on a separate GPU.

What is the best way of distributing this task across multiple GPUs and then collecting the results from each GPU onto one? It doesn’t seem to fit in with the paradigm of torch.nn.DataParallel where one model is replicated on each GPU and the data is passed through the model and then collected.

There isn’t an automatic way to do this. If A is a list of Tensors, each on a separate GPU, I presume A is a large matrix, with rows 0 to i on GPU0, i to j on GPU1, etc.

Let’s assume B is only on GPU 0, because you didn’t mention anything about B.
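
Under that assumed row-wise layout, the product splits cleanly by blocks, which is why each GPU can work on its own slice independently. The symbols $g$ (number of GPUs) and $A_k$ (the slice held on GPU $k$) are notation introduced here for illustration:

$$
A = \begin{bmatrix} A_0 \\ A_1 \\ \vdots \\ A_{g-1} \end{bmatrix}
\quad\Longrightarrow\quad
C = A B = \begin{bmatrix} A_0 B \\ A_1 B \\ \vdots \\ A_{g-1} B \end{bmatrix}
$$

Each block $A_k B$ needs only the local slice $A_k$ and a copy of $B$, and stacking the blocks in order gives $C$.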

Here’s a sample snippet showing how to parallelize this operation over multiple GPUs and collect the result on GPU0.

import torch

ngpu = torch.cuda.device_count()
N, M, K = 128, 256, 512  # example matmul sizes (N is assumed divisible by ngpu)

A = []
B = torch.randn(M, K, device='cuda:0')

# randomly initialize A
for i in range(ngpu):
    # each GPU has a slice of A
    A.append(torch.randn(N // ngpu, M, device='cuda:' + str(i)))

# now let's matmul
# Step 1: make a copy of B on each GPU
B_ = [B]
for i in range(ngpu):
    if i != 0:
        B_.append(B.to('cuda:' + str(i)))

# Step 2: issue the matmul on each GPU
C_ = []
for i in range(ngpu):
    C_.append(torch.matmul(A[i], B_[i]))

# Step 3: gather the partial results into one tensor on GPU 0
# (without device='cuda:0' the buffer would be allocated on the CPU)
C = torch.empty(N, K, device='cuda:0')
for i in range(ngpu):
    start_index = i * (N // ngpu)
    C[start_index:start_index + (N // ngpu), :].copy_(C_[i])
# C is the final result gathered on GPU-0
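
Since the question mentions autograd, here is a minimal sketch of the same idea wrapped in a function, with gradients checked at the end. The helper name `sharded_matmul` is hypothetical, not a PyTorch API. It relies on `Tensor.to` and `torch.cat` being differentiable, and it falls back to two CPU "devices" when no GPU is visible so that it still runs anywhere.

```python
import torch


def sharded_matmul(A_shards, B, out_device=None):
    """Multiply a row-sharded A (one tensor per device) by B.

    Hypothetical helper for illustration; the result is gathered on out_device.
    """
    if out_device is None:
        out_device = B.device
    # .to() is differentiable, so gradients flow back to B and to every shard
    partials = [a @ B.to(a.device) for a in A_shards]
    # concatenating along dim 0 restores the original row order
    return torch.cat([p.to(out_device) for p in partials], dim=0)


# Use every visible GPU, or pretend the CPU is two devices (assumption for portability)
ngpu = torch.cuda.device_count()
if ngpu > 0:
    devices = [torch.device('cuda', i) for i in range(ngpu)]
else:
    devices = [torch.device('cpu')] * 2

N, M, K = 128, 256, 512
A_full = torch.randn(N, M)

# torch.chunk also copes with N not being divisible by the number of devices
A_shards = [
    s.clone().to(d).requires_grad_()
    for s, d in zip(torch.chunk(A_full, len(devices), dim=0), devices)
]
B = torch.randn(M, K, device=devices[0], requires_grad=True)

C = sharded_matmul(A_shards, B)  # lives on devices[0]
C.sum().backward()               # populates B.grad and each shard's .grad
```

After `backward()`, each entry of `A_shards` holds its gradient on its own device, and `B.grad` accumulates the contributions from all devices on `devices[0]`.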

Thanks a lot for responding so quickly! Is there a way of using torch.nn.parallel to matmul in parallel without each GPU waiting for the previous GPU to finish?

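Yes. In general, each GPU won’t wait for the previous GPU at all, unless allocations are occurring (like malloc / free). The CUDA API is asynchronous: a call such as torch.matmul only queues the kernel on its device and returns to Python right away, so snippets such as: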

C_ = []
for i in range(ngpu):
    C_.append(torch.matmul(A[i], B_[i]))

happen in parallel.
There are some subtle cases where you have to switch the CUDA stream, but I don’t think you’ll run into them unless you are doing something very specific.
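
One way to see the asynchronous behaviour for yourself is to time the launch loop separately from the point where all devices have finished. This is a rough sketch rather than a benchmark; the function name `time_multi_gpu_matmul` and the matrix size are arbitrary choices made for illustration. It returns `None` when no CUDA device is visible.

```python
import time
import torch


def time_multi_gpu_matmul(size=4096):
    """Return (launch_seconds, total_seconds), or None if no GPU is visible."""
    ngpu = torch.cuda.device_count()
    if ngpu == 0:
        return None

    A = [torch.randn(size, size, device=f'cuda:{i}') for i in range(ngpu)]
    B_ = [torch.randn(size, size, device=f'cuda:{i}') for i in range(ngpu)]

    # warm-up so one-time library initialization is not included in the timing
    for i in range(ngpu):
        torch.matmul(A[i], B_[i])
        torch.cuda.synchronize(i)

    t0 = time.perf_counter()
    # the loop only queues work on each device and returns immediately
    C_ = [torch.matmul(A[i], B_[i]) for i in range(ngpu)]
    t_launch = time.perf_counter() - t0

    # block until every device has actually finished its matmul
    for i in range(ngpu):
        torch.cuda.synchronize(i)
    t_total = time.perf_counter() - t0
    return t_launch, t_total


result = time_multi_gpu_matmul()
if result is None:
    print("No CUDA devices visible; nothing to time.")
else:
    print(f"launch loop: {result[0]:.6f}s, until all GPUs finished: {result[1]:.6f}s")
```

If the launch time is much smaller than the total time, the Python loop returned before the GPUs finished, which is what lets the per-device matmuls overlap.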