Multiple input model architecture


I’m really new to machine learning and I’d like some advice.
My problem is the following:
I have 2 images (the first is 256x256 and the second 64x64) and some data (a list of 10 floats) as input, and I’d like to classify the data into 4 classes (for now).

Is there a way to organize my data to fit it into a standard model with one input?
Is there a way to design my model so that it takes each input separately?

Thank you

The main question for the first approach would be: how would you stack your different inputs together?
The images might be resized to a common size and stacked together, but how would you handle the list data?
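A minimal sketch of the resizing idea, assuming single-channel images and bilinear upsampling (an arbitrary choice; downsampling the larger image would also work). The list of 10 floats has no spatial layout, which is why it does not fit naturally into the stacked tensor:

```python
import torch
import torch.nn.functional as F

batch_size = 2
x1 = torch.randn(batch_size, 1, 256, 256)  # first image
x2 = torch.randn(batch_size, 1, 64, 64)    # second image

# Upsample the 64x64 image to 256x256 so both images share a spatial size
x2_up = F.interpolate(x2, size=(256, 256), mode="bilinear", align_corners=False)

# Stack the images as two channels of a single input tensor
stacked = torch.cat((x1, x2_up), dim=1)
print(stacked.shape)  # torch.Size([2, 2, 256, 256])
```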

You could easily create a model architecture where each input has its own path and the paths are concatenated at some point. Assuming your images have a single channel, here is a small example:

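import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Path for the 256x256 image: conv keeps 256x256, pooling halves it to 128x128
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        # Path for the 64x64 image: conv keeps 64x64, pooling halves it to 32x32
        self.features2 = nn.Sequential(
            nn.Conv2d(1, 3, 3, 1, 1),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        # Path for the list of 10 floats
        self.features3 = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
        )
        # Flattened sizes: 3*128*128 from path 1, 3*32*32 from path 2, 5 from path 3
        self.classifier = nn.Linear(128*128*3 + 32*32*3 + 5, 4)

    def forward(self, x1, x2, x3):
        x1 = self.features1(x1)
        x2 = self.features2(x2)
        x3 = self.features3(x3)

        # Flatten each path to [batch_size, features]
        x1 = x1.view(x1.size(0), -1)
        x2 = x2.view(x2.size(0), -1)
        x3 = x3.view(x3.size(0), -1)
        # Concatenate all paths along the feature dimension
        x = torch.cat((x1, x2, x3), dim=1)
        x = self.classifier(x)
        return x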

model = MyModel()
batch_size = 1
x1 = torch.randn(batch_size, 1, 256, 256)
x2 = torch.randn(batch_size, 1, 64, 64)
x3 = torch.randn(batch_size, 10)

output = model(x1, x2, x3)

Hi, I have a simple question.
After implementing such a multi-input model with 3 inputs and 1 output, how should I call loss.backward()?
Is it the same as usual if all parameters have requires_grad = True?

Since you are only dealing with a single output, you can calculate the loss as usual and just call loss.backward() to compute all gradients.
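A minimal sketch of one training step, assuming a classification target with one class index per sample and `nn.CrossEntropyLoss`. `TinyMultiInput` is a hypothetical stand-in that uses small linear branches to keep the example short; the same steps apply to a convolutional multi-input model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a multi-input model: three branches fused by concatenation
class TinyMultiInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Linear(8, 4)
        self.branch2 = nn.Linear(6, 4)
        self.branch3 = nn.Linear(10, 4)
        self.classifier = nn.Linear(12, 4)

    def forward(self, x1, x2, x3):
        x = torch.cat((self.branch1(x1), self.branch2(x2), self.branch3(x3)), dim=1)
        return self.classifier(x)

torch.manual_seed(0)
model = TinyMultiInput()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x1, x2, x3 = torch.randn(5, 8), torch.randn(5, 6), torch.randn(5, 10)
target = torch.randint(0, 4, (5,))  # one class index per sample

optimizer.zero_grad()
output = model(x1, x2, x3)          # single output tensor, shape [5, 4]
loss = criterion(output, target)    # single scalar loss
loss.backward()                     # gradients flow back into all three branches
optimizer.step()
```

Because all branches feed into the same output, a single `loss.backward()` reaches every parameter that contributed to it.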


Thank you, that sounds good. I will try that ~

Hi, I’m in a similar situation: I’m feeding three images into a network where they take different paths but are ultimately merged before the final layer. What changes do I need to make in the DataLoader and the scheduler to accommodate this (feeding three images into the network at once)?
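One possible setup, sketched under assumptions: `ThreeImageDataset` is hypothetical and returns random tensors in place of real images. If `__getitem__` returns a tuple, the default `collate_fn` of `DataLoader` batches each element separately, so the loop receives three image batches plus the labels. A learning-rate scheduler wraps the optimizer rather than the data, so it should not need any change for multiple inputs:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ThreeImageDataset(Dataset):
    """Hypothetical dataset: random tensors stand in for three real images per sample."""
    def __init__(self, num_samples=10):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        img1 = torch.randn(3, 64, 64)
        img2 = torch.randn(3, 64, 64)
        img3 = torch.randn(3, 32, 32)  # inputs may have different shapes
        label = idx % 4
        # Return all inputs together with the target for this sample
        return img1, img2, img3, label

loader = DataLoader(ThreeImageDataset(), batch_size=4, shuffle=True)

for img1, img2, img3, label in loader:
    # The default collate function batched each tuple element separately
    print(img1.shape, img2.shape, img3.shape, label.shape)
    # here you would call your model, e.g. model(img1, img2, img3)
    break
```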

Hello, I’m a bit confused as to how I could do this in a segmentation problem. At which point in the segmentation network can I concatenate the scalar list?

It’s hard to give a general answer, as it’s unclear what would work the best.
Could you explain your use case a bit and what the scalar input represents?

I have a 3D image and scalars that I want to give as input to the network. The image is a medical scan and the list of scalars represents clinical values. I would like to input both the image and the scalars to my segmentation network (UNet-like).

My concern is where exactly I should concatenate both inputs.

Should I concatenate the scalars to the image features in the bottleneck of the network by flattening them, and then proceed to the decoder part via a linear layer? Is there a better approach?

Thank you in advance

I’m really unsure what might work best.
Your bottleneck approach sounds reasonable and I would give it a try.
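A minimal sketch of the bottleneck idea described above, assuming a 3D encoder output of shape `[B, C, D, H, W]` and a vector of clinical scalars per scan. `BottleneckFusion` and all sizes are hypothetical, and the skip connections of the UNet are left untouched:

```python
import torch
import torch.nn as nn

class BottleneckFusion(nn.Module):
    """Hypothetical bottleneck block: fuses flattened 3D features with clinical scalars."""
    def __init__(self, channels, spatial, num_scalars):
        super().__init__()
        self.channels = channels
        self.spatial = spatial  # (D, H, W) of the bottleneck feature map
        flat = channels * spatial[0] * spatial[1] * spatial[2]
        # Map [image features + scalars] back to the flattened bottleneck size
        self.fuse = nn.Linear(flat + num_scalars, flat)

    def forward(self, feats, scalars):
        b = feats.size(0)
        x = feats.flatten(1)                # [B, C*D*H*W]
        x = torch.cat((x, scalars), dim=1)  # append the clinical scalars
        x = torch.relu(self.fuse(x))
        # Restore the 5D shape so the decoder can continue as usual
        return x.view(b, self.channels, *self.spatial)

fusion = BottleneckFusion(channels=16, spatial=(4, 4, 4), num_scalars=5)
bottleneck = torch.randn(2, 16, 4, 4, 4)  # encoder output
clinical = torch.randn(2, 5)              # clinical values per scan
out = fusion(bottleneck, clinical)
print(out.shape)  # torch.Size([2, 16, 4, 4, 4])
```

The linear layer has roughly `(C*D*H*W)^2` weights, so this stays affordable only when the bottleneck is small; for a larger bottleneck, one alternative is to broadcast the scalars over the spatial dimensions and concatenate them as extra channels.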


Hi, thanks for this explanation.

I have one question.
Can we take the maximum of all three features x1, x2 and x3 instead of concatenating them? If yes, then how can we compute the maximum of all three features?

If taking the maximum is not the correct approach, please let me know the correct way to combine all three features.


Would you like to get the elementwise maximum of these features, i.e. a result that might contain values from all input tensors, or would you like to use a scalar criterion to select only a single input tensor?

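The first use case should work by calling torch.max twice, since its two-tensor form only accepts a pair of tensors. Note that the elementwise comparison requires the features to have the same (or broadcastable) shapes.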
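A small sketch of the elementwise maximum, assuming the three feature tensors already share the same shape. In `MyModel` the flattened paths have different sizes, so each path would first need to be mapped to a common size (for example with an extra `nn.Linear`, an assumption not in the original code):

```python
import torch

# Elementwise max needs features of the same (or broadcastable) shape
x1 = torch.tensor([[1.0, 5.0, 2.0]])
x2 = torch.tensor([[3.0, 1.0, 2.5]])
x3 = torch.tensor([[0.0, 4.0, 6.0]])

# The two-tensor form of torch.max compares a pair, so nest the calls
out_nested = torch.max(torch.max(x1, x2), x3)

# Equivalent: stack along a new dim and reduce over it (.values drops the indices)
out_stacked = torch.stack((x1, x2, x3), dim=0).max(dim=0).values

print(out_nested)  # tensor([[3., 5., 6.]])
```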

Hello, I have a related question. I have a network that takes two inputs, but in some cases (only during evaluation) I need to pass just one of them. I don’t want to pass random values or zeros for the missing input, to save processing time. How can I do this?

If you are not using the second input during evaluation, you could pass None for it.
As long as your forward method skips the code that uses this input when it is None, you shouldn’t get any errors.
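A minimal sketch of an optional second input, assuming the two paths are fused by addition so the classifier’s input size does not depend on whether `x2` is present (with concatenation, skipping an input would change that size). `TwoInputModel` and its sizes are hypothetical, and whether the predictions stay meaningful without the second input depends on how the model was trained:

```python
import torch
import torch.nn as nn

class TwoInputModel(nn.Module):
    """Hypothetical model whose second input is optional."""
    def __init__(self):
        super().__init__()
        self.path1 = nn.Linear(10, 8)
        self.path2 = nn.Linear(6, 8)
        self.classifier = nn.Linear(8, 4)

    def forward(self, x1, x2=None):
        x = torch.relu(self.path1(x1))
        if x2 is not None:
            # Only run the second path when that input is given
            x = x + torch.relu(self.path2(x2))
        return self.classifier(x)

model = TwoInputModel().eval()
with torch.no_grad():
    both = model(torch.randn(3, 10), torch.randn(3, 6))
    only_first = model(torch.randn(3, 10))  # x2 defaults to None, path2 is skipped
print(both.shape, only_first.shape)
```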

Hi, thanks for this discussion, I managed to build a multi-input model.
Now my question is: instead of just concatenating the features, what if I want to give them different weights/importance, e.g. input 1 is more important than input 2 (75:25)?
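One way to express a 75:25 importance, sketched under assumptions: the two feature tensors have the same shape and are combined by a weighted sum instead of concatenation. `WeightedFusion` is hypothetical; with `learnable=True` the weights start at 75:25 and are then trained:

```python
import torch
import torch.nn as nn

class WeightedFusion(nn.Module):
    """Hypothetical fusion of two same-sized feature tensors with importance weights."""
    def __init__(self, w1=0.75, w2=0.25, learnable=False):
        super().__init__()
        # Store log-weights; softmax maps them back to weights that sum to 1
        init = torch.log(torch.tensor([w1, w2]))
        if learnable:
            self.logits = nn.Parameter(init)
        else:
            self.register_buffer("logits", init)  # fixed, not updated by the optimizer

    def forward(self, f1, f2):
        w = torch.softmax(self.logits, dim=0)
        return w[0] * f1 + w[1] * f2

fixed = WeightedFusion()
f1 = torch.ones(2, 4)
f2 = torch.zeros(2, 4)
out = fixed(f1, f2)
print(out)  # every entry is approximately 0.75
```

Scaling the features before concatenation (e.g. `torch.cat((0.75 * f1, 0.25 * f2), dim=1)`) is also possible, but a following linear layer can learn to undo a constant scale, so it does not enforce the importance ratio.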