CNN with multiple images as input

I want a network that takes multiple images as input, where each image is processed by its own CNN: the same architecture for every image, but different parameters.
For example, say each batch has 10 images of shape 3×224×224, and I want each image to go through a CNN with the same architecture as the others but with its own parameters.

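One way to accomplish this is with a grouped convolution:

  1. Concatenate the images along the channel dimension.
  2. Set the Conv2d groups argument to 10 (in this case) and in_channels to 30 (3 channels × 10 images). out_channels must also be divisible by groups, e.g. 640 for 64 feature maps per image.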
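The grouped-convolution approach above can be sketched as follows. This is a minimal sketch, assuming the input is laid out as (batch, num_images, 3, H, W) and that each image gets a single 3→64 conv layer; the class name GroupedCNN and the constants are illustrative, not from the thread. With groups=10, Conv2d splits the 30 input channels and 640 output channels into 10 independent blocks, so each image's features depend only on that image and on its own slice of the weight tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_IMAGES = 10
IN_CH = 3
OUT_CH = 64  # feature maps per image (illustrative choice)


class GroupedCNN(nn.Module):
    """One conv layer per image, implemented as a single grouped Conv2d."""

    def __init__(self, num_images=NUM_IMAGES):
        super().__init__()
        # groups=num_images: each group maps its own 3 input channels
        # to its own 64 output channels with independent weights.
        self.conv = nn.Conv2d(IN_CH * num_images, OUT_CH * num_images,
                              kernel_size=3, padding=1, groups=num_images)

    def forward(self, x):
        # x: (batch, num_images, 3, H, W) -> concatenate images along channels
        b, n, c, h, w = x.shape
        out = self.conv(x.reshape(b, n * c, h, w))  # (batch, n*64, H, W)
        # split the channel dim back into per-image feature maps
        return out.reshape(b, n, OUT_CH, h, w)


model = GroupedCNN()
x = torch.rand(2, NUM_IMAGES, IN_CH, 32, 32)
y = model(x)
print(y.shape)  # torch.Size([2, 10, 64, 32, 32])
```

Image g uses weight rows g*64 to (g+1)*64 of model.conv.weight, so the 10 "CNNs" live in one parameter tensor but never mix.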

Alternatively, you could just do something like this:

import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, num_images=10):
        super().__init__()
        # One branch per image: same architecture, independent parameters
        self.layer1 = nn.ModuleList()
        self.dropout = nn.Dropout(0.1)  # has no parameters, so it can be shared
        for i in range(num_images):
            self.layer1.append(nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), self.dropout))

    def forward(self, x):
        # x: (num_images, 3, 224, 224); image i goes through branch i.
        # x[i:i + 1] keeps a batch dimension of size 1.
        outputs = [self.layer1[i](x[i:i + 1]) for i in range(len(self.layer1))]
        return torch.cat(outputs)  # (num_images, 64, 224, 224)

dummy_data = torch.rand((10, 3, 224, 224))
model = CNN()
print(model(dummy_data).size())

Thank you very much for your answer, @J_Johnson. What if I want to do parameter sharing between the CNNs of each image?
Additionally, would it be possible to have a variable number of images? Let's say that some batches could have 10 images, but others could have 13.

What do you mean by “parameter sharing”?
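If "parameter sharing" means that every image should go through the same CNN with the same weights, a minimal sketch (assuming input of shape (batch, num_images, 3, H, W); SharedCNN is a hypothetical name) is to fold the image dimension into the batch dimension and use a single Conv2d. Because nothing depends on num_images, this version also handles a variable number of images for free.

```python
import torch
import torch.nn as nn


class SharedCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # a single conv whose weights are reused for every image
        self.conv = nn.Conv2d(3, 64, 3, padding=1)

    def forward(self, x):
        # x: (batch, num_images, 3, H, W); treat every image as a batch item
        b, n, c, h, w = x.shape
        out = self.conv(x.reshape(b * n, c, h, w))
        # restore the (batch, num_images, ...) layout
        return out.reshape(b, n, *out.shape[1:])


model = SharedCNN()
print(model(torch.rand(2, 10, 3, 32, 32)).shape)  # torch.Size([2, 10, 64, 32, 32])
print(model(torch.rand(2, 13, 3, 32, 32)).shape)  # torch.Size([2, 13, 64, 32, 32])
```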

If you have a variable number of images and use the “groups” option, you’ll need to pass in zero-filled images where there are no images and set groups to the maximum number of images you expect.
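The zero-padding idea for the groups option might look like this. It is a sketch under two assumptions: every sample in a batch has the same number of images, and the maximum is 13 (pad_to_max and MAX_IMAGES are hypothetical names). Note that an all-zero image still produces a non-zero output (the conv bias), so the outputs of padded slots are sliced off afterwards rather than used.

```python
import torch
import torch.nn as nn

MAX_IMAGES = 13


def pad_to_max(x, max_images=MAX_IMAGES):
    # x: (batch, n, C, H, W) with n <= max_images; append all-zero images
    b, n, c, h, w = x.shape
    pad = x.new_zeros(b, max_images - n, c, h, w)
    return torch.cat([x, pad], dim=1), n


# grouped conv sized for the maximum number of images
conv = nn.Conv2d(3 * MAX_IMAGES, 64 * MAX_IMAGES, 3, padding=1, groups=MAX_IMAGES)

x = torch.rand(2, 10, 3, 32, 32)  # this batch happens to have 10 images
padded, n_real = pad_to_max(x)
b, _, _, h, w = padded.shape
out = conv(padded.reshape(b, -1, h, w)).reshape(b, MAX_IMAGES, 64, h, w)
# padded slots only contain the bias response; keep the real images
out = out[:, :n_real]
print(out.shape)  # torch.Size([2, 10, 64, 32, 32])
```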

For the second option (the ModuleList), you’ll need to define the maximum number of images in __init__. In the forward pass, you’ll need to decide which of the models to apply to each image, for example by passing an index array into forward, looping with for i in range(len(image_index)):, and using self.layer1[image_index[i]] in place of self.layer1[i].
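A minimal sketch of that index-based forward pass, assuming a maximum of 13 images and one Conv2d per branch (the image_index argument follows the suggestion above; the rest is illustrative). Branches that are not selected in a given batch simply receive no gradient for that step.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, max_images=13):
        super().__init__()
        # one independent branch per possible image slot
        self.layer1 = nn.ModuleList(
            [nn.Conv2d(3, 64, 3, padding=1) for _ in range(max_images)]
        )

    def forward(self, x, image_index):
        # x: (n, 3, H, W) with n == len(image_index);
        # image_index[i] selects which branch processes image i
        outputs = [self.layer1[image_index[i]](x[i:i + 1])
                   for i in range(len(image_index))]
        return torch.cat(outputs)


model = CNN(max_images=13)
out10 = model(torch.rand(10, 3, 32, 32), list(range(10)))
out13 = model(torch.rand(13, 3, 32, 32), list(range(13)))
print(out10.shape, out13.shape)
```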