A network with multiple inputs seems to get more arguments than it is actually given

Hi,

I created an Inception-V3 model with two inputs by duplicating the forward pass for the second input. During the forward pass I get an error saying that it receives 3 arguments instead of 2 (the two inputs).
I did the same thing with a simple conv-net and it worked fine. Any idea why this happens and how to fix it?

Note: it suffices to look at Inception3_Mult() (mainly its forward pass) and the Main part (at the bottom).

from collections import namedtuple
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Callable, Any, Optional, Tuple, List

# Assumed definition: torchvision's InceptionOutputs namedtuple, extended here
# with the two feature fields used below (the original has only logits and aux_logits).
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'feature1', 'feature2', 'aux_logits'])

class Inception3_Mult(nn.Module):

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = False,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None
    ) -> None:
        super(Inception3_Mult, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor, x2: Tensor):
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        feature1 = x
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        feature2 = x
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        ################

        # N x 3 x 299 x 299
        x2 = self.Conv2d_1a_3x3(x2)
        # N x 32 x 149 x 149
        x2 = self.Conv2d_2a_3x3(x2)
        # N x 32 x 147 x 147
        x2 = self.Conv2d_2b_3x3(x2)
        # N x 64 x 147 x 147
        x2 = self.maxpool1(x2)
        # N x 64 x 73 x 73
        x2 = self.Conv2d_3b_1x1(x2)
        # N x 80 x 73 x 73
        x2 = self.Conv2d_4a_3x3(x2)
        # N x 192 x 71 x 71
        x2 = self.maxpool2(x2)
        # N x 192 x 35 x 35
        x2 = self.Mixed_5b(x2)
        # N x 256 x 35 x 35
        x2 = self.Mixed_5c(x2)
        # N x 288 x 35 x 35
        x2 = self.Mixed_5d(x2)
        # N x 288 x 35 x 35
        x2 = self.Mixed_6a(x2)
        # N x 768 x 17 x 17
        x2 = self.Mixed_6b(x2)
        # N x 768 x 17 x 17
        x2 = self.Mixed_6c(x2)
        # N x 768 x 17 x 17
        x2 = self.Mixed_6d(x2)
        # N x 768 x 17 x 17
        x2 = self.Mixed_6e(x2)
        # N x 768 x 17 x 17
        aux2: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux2 = self.AuxLogits(x2)
        # N x 768 x 17 x 17
        x2 = self.Mixed_7a(x2)
        # N x 1280 x 8 x 8
        x2 = self.Mixed_7b(x2)
        # N x 2048 x 8 x 8
        x2 = self.Mixed_7c(x2)
        # N x 2048 x 8 x 8
        feature1_2 = x2
        # Adaptive average pooling
        x2 = self.avgpool(x2)
        # N x 2048 x 1 x 1
        x2 = self.dropout(x2)
        # N x 2048 x 1 x 1
        x2 = torch.flatten(x2, 1)
        feature2_2 = x2
        # N x 2048
        x2 = self.fc(x2)
        # N x 1000 (num_classes)

        print('feature1_2.size(): ', feature1_2.size())
        print('feature2_2.size(): ', feature2_2.size())

        return x, feature1, feature2, aux, x2, feature1_2, feature2_2, aux2

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, feature1: Tensor, feature2: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        if self.training and self.aux_logits:
            return InceptionOutputs(x, feature1, feature2, aux)
        else:
            return x, feature1, feature2  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x, feature1, feature2, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, feature1, feature2, aux)
        else:
            return self.eager_outputs(x, feature1, feature2, aux)

class InceptionA(nn.Module):

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

class InceptionB(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

class InceptionC(nn.Module):

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

class InceptionD(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

class InceptionE(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)

class InceptionAux(nn.Module):

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000
        return x

class BasicConv2d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

####### Main
model2 = Inception3_Mult()
inputs = torch.rand(16, 3, 400, 400)  # random tensor
x, feature1, feature2, x2, feature1_2, feature2_2 = model2(inputs, inputs)

Error:
TypeError: forward() takes 2 positional arguments but 3 were given

When it gets only one input:
x, feature1, feature2, x2, feature1_2, feature2_2 = model2(inputs)

it says (as expected):
TypeError: _forward() missing 1 required positional argument: 'x2'

What I think is happening: in the forward method of your Inception3_Mult module, only one input is defined. (The error message counts self as well: calling model2(inputs, inputs) invokes forward with three positional arguments, while your forward accepts only two.)

def forward(self, x: Tensor) -> InceptionOutputs:

And inside this forward pass you call your _forward method, which takes two inputs:

def _forward(self, x: Tensor, x2: Tensor):

But you pass it only the single input that forward received:

x, feature1, feature2, aux = self._forward(x)

So, you need to either:

  • Accept two inputs in your forward (see the sketch below), or
  • Expect only one input in your _forward, or
  • Pass the same x twice to your _forward.
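
For the first option, a minimal sketch of what your forward could look like, reusing the names from your code (the eager_outputs / scripting dispatch is left out for brevity):

def forward(self, x: Tensor, x2: Tensor):
    x = self._transform_input(x)
    x2 = self._transform_input(x2)
    # _forward returns eight values: logits, two feature tensors and aux, once per input
    x, feature1, feature2, aux, x2, feature1_2, feature2_2, aux2 = self._forward(x, x2)
    return x, feature1, feature2, aux, x2, feature1_2, feature2_2, aux2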

Hope this helps :smile:


Yes, you’re totally right!

I only changed _forward and ignored forward :slight_smile:
Now I understand the problem.
I deleted the forward and renamed "_forward" to "forward", and now it works well. Thanks!
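
In case it helps anyone else: after the rename, the call has to unpack all eight values that _forward returns (aux and aux2 are None outside training, or when aux_logits=False), something like:

x, feature1, feature2, aux, x2, feature1_2, feature2_2, aux2 = model2(inputs, inputs)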

Just out of curiosity, any idea why they used both the regular "forward" and "_forward"?


I'm guessing they did it for code clarity, keeping the actual forward as small as possible by moving the whole architecture into this other method.

But I might be wrong :smile:
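
One concrete reason is also visible in the code above: forward checks torch.jit.is_scripting() so that a scripted model always returns one fixed type, while eager mode may drop the aux output. Stripped down, torchvision's original pattern looks roughly like this (two fields there; the version in this thread just carries the extra feature tensors):

def forward(self, x: Tensor):
    x = self._transform_input(x)
    x, aux = self._forward(x)            # the architecture lives in _forward
    if torch.jit.is_scripting():
        return InceptionOutputs(x, aux)  # scripting needs a single return type
    return self.eager_outputs(x, aux)    # eager mode may return just the logits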
