Parallelize the application of multiple CNNs to multiple images

Hello!

I’m testing out a rather unusual method that requires applying n different CNNs to n different images. It is a one-to-one mapping: cnn_1 is applied to img_1, cnn_2 is applied to img_2, and so on. Currently I loop through each of the n networks and images to perform the forward passes, which gets rather slow for large n. I was just wondering if there is a way to parallelize this operation so that all n forward passes can happen simultaneously?
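To make it concrete, here is a simplified stand-in for what I’m doing (dummy single-layer networks and random images in place of my real CNNs and data):

import torch

n = 100
cnns = [torch.nn.Conv2d(3, 4, 3) for _ in range(n)]   # stand-ins for the real CNNs
imgs = [torch.rand(1, 3, 50, 50) for _ in range(n)]   # one image per network

# current approach: a Python-level loop, one forward pass at a time
outputs = [cnn(img) for cnn, img in zip(cnns, imgs)]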

I hope you have enough GPUs & memory to scale your system well as n increases!

Otherwise, you can use the torch.distributed.launch module. Take a look at the snippet below; it might give you a better idea of how you can easily parallelize and even distribute compute using PyTorch.
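Here is a minimal sketch of the idea (my own illustration, with dummy single-layer networks and random images as placeholders): each process started by the launcher handles a disjoint slice of the (cnn, image) pairs on its own GPU, and since the forward passes are independent, no inter-process communication is needed.

# run with: python -m torch.distributed.launch --nproc_per_node=4 script.py
import argparse
import os
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by the launcher
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)

world_size = int(os.environ.get("WORLD_SIZE", "1"))  # number of launched processes

n = 100
cnns = [torch.nn.Conv2d(3, 4, 3) for _ in range(n)]   # dummy stand-ins
imgs = [torch.rand(1, 3, 50, 50) for _ in range(n)]

# each process takes every world_size-th pair; the passes are independent,
# so torch.distributed.init_process_group isn't even needed here
for i in range(args.local_rank, n, world_size):
    out = cnns[i].cuda(args.local_rank)(imgs[i].cuda(args.local_rank))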

Thanks @LeviViana! If my understanding of torch.distributed.launch is correct, it can be used to distribute compute over multiple GPUs? However, I’d like to parallelize this operation on a single GPU. Ideally, it would work like torch.bmm, the only difference being that what gets batched is a CNN forward pass rather than a matrix multiplication.
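For reference, this is the kind of batching torch.bmm gives for matrix multiplication:

import torch

a = torch.rand(10, 2, 3)
b = torch.rand(10, 3, 4)
c = torch.bmm(a, b)  # c[i] == a[i] @ b[i] for all 10 pairs, in a single call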

Here is a snippet explaining how you can achieve this using grouped convolutions:

import torch

image_1 = torch.rand(1, 3, 50, 50)
image_2 = torch.rand(1, 3, 50, 50)

conv_weight_1 = torch.rand([4, 3, 3, 3]) # in_channels=3, out_channels=4, kernel_size=(3, 3)
conv_weight_2 = torch.rand([4, 3, 3, 3]) # in_channels=3, out_channels=4, kernel_size=(3, 3)

conv_bias_1 = torch.rand([4]) # out_channels=4
conv_bias_2 = torch.rand([4]) # out_channels=4

# 1st case -> not fused convolutions:

res_1 = torch.nn.functional.conv2d(image_1, conv_weight_1, conv_bias_1) # [1, 4, 48, 48]
res_2 = torch.nn.functional.conv2d(image_2, conv_weight_2, conv_bias_2) # [1, 4, 48, 48]

res_not_fused = torch.cat((res_1, res_2), dim=0) # [2, 4, 48, 48]

# 2nd case -> fused convolutions:

conv_weight = torch.cat((conv_weight_1, conv_weight_2), dim=0) # [8, 3, 3, 3]
conv_bias = torch.cat((conv_bias_1, conv_bias_2)) # [8]

batch = torch.cat((image_1, image_2), dim=1) # [1, 6, 50, 50]: images stacked along channels

res_fused = torch.nn.functional.conv2d(batch, conv_weight, conv_bias, groups=2) # [1, 8, 48, 48]
res_fused = res_fused.view(2, 4, 48, 48) # back to one result per image

res_fused.allclose(res_not_fused) # <- True: the results match (up to floating-point tolerance)

I needed to use torch.allclose instead of torch.equal because there is a very small numerical difference between the two methods.
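The same trick extends to n networks. Here is a sketch (my own generalization of the snippet above, again with dummy layers and images) that fuses the weights of n independent Conv2d layers into a single grouped convolution:

import torch
import torch.nn.functional as F

n = 8
convs = [torch.nn.Conv2d(3, 4, 3) for _ in range(n)]   # n independent conv layers
images = [torch.rand(1, 3, 50, 50) for _ in range(n)]  # one image per layer

weight = torch.cat([c.weight for c in convs], dim=0)   # [n*4, 3, 3, 3]
bias = torch.cat([c.bias for c in convs], dim=0)       # [n*4]
batch = torch.cat(images, dim=1)                       # [1, n*3, 50, 50]

out = F.conv2d(batch, weight, bias, groups=n)          # [1, n*4, 48, 48]
out = out.view(n, 4, 48, 48)                           # one result per (conv, image) pair

For a multi-layer CNN you would repeat this fusion layer by layer; ops that aren’t grouped-convolution-friendly (e.g. batch norm with per-network statistics) would need separate handling.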

Apologies, I just saw this! Thank you so much for the snippet! :slightly_smiling_face: