I see. So, I found a workaround: you can put the two inputs in a dictionary. I tried it with a dictionary, but it should also be possible to do it with a list. Here is a working example:
class BasicBlock(nn.Module):
    """Two-branch block: applies ReLU + 3x3 conv to each of two inputs and
    concatenates every result with its own input along the channel axis.

    Inputs and outputs are dicts ``{0: tensor, 1: tensor}`` so that several
    blocks can be chained inside ``nn.Sequential`` (which only forwards a
    single argument between modules).
    """

    def __init__(self, in_planes_1, in_planes_2, out_planes):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # One conv per branch; both produce `out_planes` channels.
        self.conv1 = nn.Conv2d(in_planes_1, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_planes_2, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

    def forward(self, x):
        # Unpack the two branches from the dict container.
        first, second = x[0], x[1]
        branch1 = self.conv1(self.relu(first))
        branch2 = self.conv2(self.relu(second))
        # Dense-style concat: each output carries its input's channels too,
        # returned under the same keys so the next block can consume it.
        return {0: torch.cat([first, branch1], 1),
                1: torch.cat([second, branch2], 1)}
class MyModel(nn.Module):
    """Chains two ``BasicBlock``s via ``nn.Sequential``.

    The dict-in/dict-out convention of ``BasicBlock`` is what lets
    ``nn.Sequential`` thread two tensors through the pipeline.
    Channel bookkeeping (cat of input + conv output):
      branch 0: 3 -> 3+8 = 11 -> 11+16 = 27
      branch 1: 1 -> 1+8 = 9  -> 9+16  = 25
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            BasicBlock(3, 1, 8),
            BasicBlock(11, 9, 16),
        )

    def forward(self, x):
        # x is {0: tensor, 1: tensor}; the same structure is returned.
        return self.main(x)
Note that I have made the input and output consistent using the dictionary.
Now we create a model:
>>> m = MyModel()
>>> m
MyModel(
(main): Sequential(
(0): BasicBlock(
(relu): ReLU(inplace)
(conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv2): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): BasicBlock(
(relu): ReLU(inplace)
(conv1): Conv2d(11, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv2): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
)
and now we pass the two arguments as follows:
>>> t1 = torch.randn(10, 3, 28, 28)
>>> t2 = torch.randn(10, 1, 28, 28)
>>> output = m({0:t1, 1:t2})
>>> output[0].shape, output[1].shape
(torch.Size([10, 27, 28, 28]), torch.Size([10, 25, 28, 28]))
>>> type(output)
<class 'dict'>
>>> output.keys()
dict_keys([0, 1])