Passing an additional argument to ResNet blocks causes an error

I’m trying to pass one more input to the residual blocks of my ResNet. The ResNet I implemented is basically torchvision's ResNet, shown below; the only difference is that BasicBlock's forward takes one more input.

model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])  # layers: blocks per stage (only layer1 shown below)

class ResNet(nn.Module):
    """Abbreviated torchvision-style ResNet whose forward takes an extra ``alpha``.

    Only ``layer1`` is shown; the remaining stages and the classifier are
    elided (``num_classes`` is accepted but unused in this excerpt).
    """

    def __init__(self, block, layers, num_classes=10):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise ``self.layer1 = ...`` raises AttributeError.
        super(ResNet, self).__init__()
        # _make_layer reads self.inplanes, so it must exist before the first call.
        # 64 mirrors torchvision's ResNet (channels after its stem, elided here).
        self.inplanes = 64
        self.layer1 = self._make_layer(block, 64, layers[0])

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one container.

        The first block gets ``stride`` and, when ``stride != 1`` or the channel
        count changes, a ``conv1x1`` + BatchNorm ``downsample`` path.
        Side effect: sets ``self.inplanes`` to ``planes * block.expansion``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        # NOTE: nn.Sequential.forward accepts a single input, so this container
        # cannot pass ``alpha`` on to the blocks (see forward below).
        return nn.Sequential(*layers)

    def forward(self, x, alpha):
        """Intended to run ``layer1`` on ``x`` with the extra input ``alpha``.

        As written this raises ``TypeError: forward() takes 2 positional
        arguments but 3 were given``, because ``self.layer1`` is a plain
        ``nn.Sequential`` whose forward takes only one input.
        """
        x = self.layer1(x, alpha)
        return x
class BasicBlock(nn.Module):
    """Abbreviated residual block whose forward takes an extra input ``alpha``."""
    # Output-channel multiplier; ResNet._make_layer uses planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        # NOTE(review): excerpt -- self.addi and self.relu (used in forward) are
        # not created in the visible code, downsample is not stored, and bn1 is
        # not applied in forward; presumably elided from the post. TODO confirm.

    def forward(self, x, alpha):
        """Apply conv1, then ``self.addi`` (which needs ``alpha``), then ``self.relu``."""
        # Unused in this excerpt; the skip connection appears to be elided.
        residual = x

        out = self.conv1(x)
        out = self.addi(out, alpha) ## Here needs 'alpha' -- why alpha must reach every block
        out = self.relu(out)
        return out

However, this raises an error:

TypeError: forward() takes 2 positional arguments but 3 were given

or, if I replace x = self.layer1(x, alpha) with x = self.layer1(x), I get:

TypeError: forward() missing 1 required positional argument: 'alpha'

How can I fix it? Thanks!

The problem is that the nn.Sequential returned by _make_layer passes only a single input through its modules, so alpha never reaches the blocks.

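I solved it by subclassing nn.Sequential and overriding its forward so that alpha is passed to every module, and then returning MySequential(*layers) instead of nn.Sequential(*layers) from _make_layer: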

class MySequential(nn.Sequential):
    """``nn.Sequential`` that forwards an extra ``alpha`` argument to every module.

    Each child is called as ``module(x, alpha)`` and its output becomes the
    next ``x``, so every child's forward must accept ``alpha``.
    Use it in place of ``nn.Sequential`` in ``ResNet._make_layer``.
    """

    def forward(self, x, alpha):
        # Iterate the container itself (public API) rather than reaching into
        # the private ``_modules`` dict; children are yielded in insertion order.
        for module in self:
            x = module(x, alpha)
        return x
1 Like