Here is a snippet showing how you can achieve this with a single grouped convolution:
import torch

# Two independent inputs and two independent sets of convolution parameters.
image_1 = torch.rand(1, 3, 50, 50)
image_2 = torch.rand(1, 3, 50, 50)
conv_weight_1 = torch.rand([4, 3, 3, 3])  # out_channels=4, in_channels=3, kernel_size=(3, 3)
conv_weight_2 = torch.rand([4, 3, 3, 3])  # out_channels=4, in_channels=3, kernel_size=(3, 3)
conv_bias_1 = torch.rand([4])  # one bias value per output channel
conv_bias_2 = torch.rand([4])  # one bias value per output channel

# 1st case -> not fused: run each convolution separately and stack the
# two (1, 4, 48, 48) results along the batch dimension -> (2, 4, 48, 48).
res_1 = torch.nn.functional.conv2d(image_1, conv_weight_1, conv_bias_1)
res_2 = torch.nn.functional.conv2d(image_2, conv_weight_2, conv_bias_2)
res_not_fused = torch.cat((res_1, res_2), dim=0)

# 2nd case -> fused: with groups=2, conv2d splits the input channels into
# two groups of 3 and the filters into two groups of 4, so group 1 applies
# (conv_weight_1, conv_bias_1) to image_1 and group 2 applies
# (conv_weight_2, conv_bias_2) to image_2 in a single call.
conv_weight = torch.cat((conv_weight_1, conv_weight_2), dim=0)  # (8, 3, 3, 3)
conv_bias = torch.cat((conv_bias_1, conv_bias_2))               # (8,)
batch = torch.cat((image_1, image_2), dim=1)                    # (1, 6, 50, 50), stacked channel-wise
res_fused = torch.nn.functional.conv2d(batch, conv_weight, conv_bias, groups=2)
# Split the 8 output channels back into 2 batch entries of 4 channels each.
res_fused = res_fused.view(2, 4, 48, 48)

# The two methods agree up to floating-point rounding, hence allclose
# rather than equal. (The original snippet discarded this boolean;
# asserting it makes the claim actually checked.)
assert torch.allclose(res_fused, res_not_fused)
Note that I needed to use torch.allclose instead of torch.equal,
because there is a very small numerical difference between the results of the two methods.