nn.Sequential(*layers) forward(): error when passing multiple inputs

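I see. I found a workaround: put the two inputs in a dictionary. `nn.Sequential` passes exactly one object from each module to the next, so packing several tensors into a single container gets around that limitation. I tried it with a dictionary, but it should also be possible to do the same with a list or tuple. Here is a working example: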

class BasicBlock(nn.Module):
    """Two parallel conv branches whose inputs and outputs share one container.

    ``nn.Sequential`` hands exactly one object from each child to the next, so
    this block receives both of its inputs packed in a single indexable
    container ``x``. That container is a dict with keys ``0`` and ``1`` (or a
    list/tuple). The block returns a dict with the same keys, so one block's
    output can be fed straight into the next block.

    Each branch computes ``conv(relu(x_i))`` with a 3x3, stride-1, padding-1
    convolution, so spatial size is preserved. The result is concatenated onto
    the branch input along dim 1. Branch ``i`` therefore grows from ``C_i``
    to ``C_i + out_planes`` channels.

    Args:
        in_planes_1: number of channels of ``x[0]``.
        in_planes_2: number of channels of ``x[1]``.
        out_planes: number of channels produced by each convolution.
    """

    def __init__(self, in_planes_1, in_planes_2, out_planes):
        super().__init__()
        # Deliberately NOT in-place. The ReLU is applied to the block's own
        # inputs, which are concatenated unchanged below. Those inputs may be
        # the caller's tensors, or leaf tensors that require grad, and an
        # in-place ReLU raises an autograd error on such leaves.
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(in_planes_1, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_planes_2, out_planes, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Apply both branches to ``x[0]`` and ``x[1]``.

        Args:
            x: container indexable by ``0`` and ``1`` holding the two input
                tensors (e.g. ``{0: t1, 1: t2}`` or ``[t1, t2]``).

        Returns:
            ``{0: cat(x[0], conv1(relu(x[0]))), 1: cat(x[1], conv2(relu(x[1])))}``,
            where each concatenation is along dim 1.
        """
        x1, x2 = x[0], x[1]
        out1 = self.conv1(self.relu(x1))
        out2 = self.conv2(self.relu(x2))
        # The raw (pre-activation) inputs are carried through unchanged.
        return {0: torch.cat([x1, out1], 1), 1: torch.cat([x2, out2], 1)}

class MyModel(nn.Module):
    """Two ``BasicBlock`` stages chained inside an ``nn.Sequential``.

    The model's input is the two-entry container that ``BasicBlock`` consumes
    (e.g. ``{0: t1, 1: t2}``). Its output is whatever the last stage returns.
    """

    def __init__(self):
        super().__init__()
        # (in_planes_1, in_planes_2, out_planes) per stage. Stage two's input
        # widths equal stage one's output widths: 3 + 8 = 11 and 1 + 8 = 9.
        stage_specs = [(3, 1, 8), (11, 9, 16)]
        layers = [BasicBlock(*spec) for spec in stage_specs]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the result."""
        result = self.main(x)
        return result

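Note that every block takes and returns the same kind of object (a dictionary with keys 0 and 1). As a result, the output of one block can be passed directly as the input of the next.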

Now we create a model:

>>> m = MyModel()
>>> m
MyModel(
  (main): Sequential(
    (0): BasicBlock(
      (relu): ReLU(inplace)
      (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): BasicBlock(
      (relu): ReLU(inplace)
      (conv1): Conv2d(11, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
)

Now we pass both inputs in a single dictionary:

>>> t1 = torch.randn(10, 3, 28, 28)
>>> t2 = torch.randn(10, 1, 28, 28)
>>> output = m({0:t1, 1:t2})
>>> output[0].shape, output[1].shape
(torch.Size([10, 27, 28, 28]), torch.Size([10, 25, 28, 28]))
>>> type(output)
<class 'dict'>
>>> output.keys()
dict_keys([0, 1])
5 Likes