TypeError: WideResNet.forward() missing 1 required positional argument: 'x'

I can train Resnet-18, Resnet-34, and Resnet-50, but I cannot initialize WideResnet-28-10 without error.

# https://github.com/AlexandrosFerles/NIPS_2019_Reproducibilty_Challenge_Zero-shot_Knowledge_Transfer_via_Adversarial_Belief_Matching/blob/master/src/PyTorch/WideResNet.py
import torch
import torch.nn as nn
# from torchsummary import summary
# import math


class IndividualBlock1(nn.Module):
    """First residual block of a group, in pre-activation order.

    The main branch is BN -> ReLU -> 3x3 conv (strided) -> BN -> ReLU ->
    3x3 conv. When ``subsample_input`` or ``increase_filters`` is set, the
    shortcut is a bias-free 1x1 convolution applied to the pre-activated
    input; otherwise the untouched input is added back.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the shortcut
            projection, so both branches always agree on spatial size.
        subsample_input: request a projection shortcut (spatial reduction).
        increase_filters: request a projection shortcut (channel change).

    Note:
        With both flags false the shortcut is the identity, which only adds
        up when ``stride == 1`` and ``input_features == output_features``.
    """

    def __init__(self, input_features, output_features, stride: int, subsample_input=True, increase_filters=True):
        super().__init__()

        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(input_features)
        self.batch_norm2 = nn.BatchNorm2d(output_features)

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1, bias=False)

        self.subsample_input = subsample_input
        self.increase_filters = increase_filters
        if subsample_input or increase_filters:
            # The projection follows `stride` instead of a hard-coded 2 / 1:
            # a 1x1 conv with stride s and a 3x3 conv with padding 1 and
            # stride s yield the same spatial size, so the two branches can
            # always be summed regardless of how the flags were combined.
            self.conv_inp = nn.Conv2d(input_features, output_features, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` for a 4-D input tensor."""
        # BN output is a fresh tensor, so the in-place ReLU never modifies `x`.
        pre_activated = self.activation(self.batch_norm1(x))

        residual = self.conv1(pre_activated)
        residual = self.activation(self.batch_norm2(residual))
        residual = self.conv2(residual)

        if self.subsample_input or self.increase_filters:
            # Projection shortcut is fed the pre-activated input.
            return self.conv_inp(pre_activated) + residual
        # Identity shortcut uses the raw input.
        return x + residual


class IndividualBlockN(nn.Module):
    """Follow-up residual block with an identity shortcut.

    Computes ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + x``. Both 3x3
    convolutions use the given ``stride``; because the input is added back
    unchanged, the output of the second convolution has to match the shape
    of ``x``.

    Args:
        input_features: number of channels of the incoming tensor.
        output_features: number of channels produced by both convolutions.
        stride: stride shared by the two 3x3 convolutions.
    """

    def __init__(self, input_features, output_features, stride: int):
        super().__init__()

        # Sub-modules are registered in the same order as before so that
        # parameter / state_dict ordering is unaffected.
        self.activation = nn.ReLU(inplace=True)

        self.batch_norm1 = nn.BatchNorm2d(num_features=input_features)
        self.batch_norm2 = nn.BatchNorm2d(num_features=output_features)

        # Settings common to both 3x3 convolutions.
        shared = dict(kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv1 = nn.Conv2d(input_features, output_features, **shared)
        self.conv2 = nn.Conv2d(output_features, output_features, **shared)

    def forward(self, x):
        """Apply the two BN-ReLU-conv stages and add the input back."""
        branch = self.conv1(self.activation(self.batch_norm1(x)))
        branch = self.conv2(self.activation(self.batch_norm2(branch)))
        return branch + x


class Nblock(nn.Module):
    """A group of ``N`` residual blocks executed one after another.

    The first block is an ``IndividualBlock1`` that receives the stride and
    the shortcut flags; each of the remaining ``N - 1`` blocks is an
    ``IndividualBlockN`` mapping ``output_features`` to ``output_features``
    with stride 1. For ``N <= 0`` the group is an empty ``nn.Sequential``.

    Args:
        N: number of residual blocks in the group.
        input_features: channels entering the group.
        output_features: channels produced by every block of the group.
        stride: stride handed to the first block.
        subsample_input: forwarded to the first block.
        increase_filters: forwarded to the first block.
    """

    def __init__(self, N: int, input_features, output_features, stride: int, subsample_input=True, increase_filters=True):
        super().__init__()

        stages = []
        if N > 0:
            # Only the leading block may change resolution / width.
            stages.append(
                IndividualBlock1(
                    input_features,
                    output_features,
                    stride=stride,
                    subsample_input=subsample_input,
                    increase_filters=increase_filters,
                )
            )
        stages.extend(
            IndividualBlockN(output_features, output_features, stride=1)
            for _ in range(N - 1)
        )

        self.nblockLayer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every block of the group in order."""
        return self.nblockLayer(x)


class WideResNet(nn.Module):
    """Wide residual network made of a stem convolution and three block groups.

    Args:
        d: depth parameter; every group holds ``(d - 4) // 6`` residual blocks.
        k: widening factor; the groups output ``16*k``, ``32*k`` and ``64*k``
            channels.
        n_classes: size of the final linear layer's output.
        input_features: number of channels of the input images.
        output_features: number of channels produced by the stem convolution.
        strides: indexable with four entries - the stride of the stem
            convolution followed by the strides of groups 1, 2 and 3.

    ``forward`` returns ``(logits, attention1, attention2, attention3)`` where
    the ``attention*`` tensors are the raw outputs of the three groups.

    Note:
        Instances are callable, but calling one runs ``forward`` and therefore
        needs an input tensor. Code that expects a model *factory* must be
        given the class (or a function returning an instance), not an instance.
    """

    def __init__(self, d, k, n_classes, input_features, output_features, strides):
        super().__init__()

        self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, stride=strides[0], padding=1, bias=False)

        filters = [16 * k, 32 * k, 64 * k]
        self.out_filters = filters[-1]
        N: int = (d - 4) // 6
        # The first group only needs a projection shortcut when it widens.
        increase_filters: bool = k > 1

        self.block1 = Nblock(N, input_features=output_features, output_features=filters[0], stride=strides[1], subsample_input=False, increase_filters=increase_filters)
        self.block2 = Nblock(N, input_features=filters[0], output_features=filters[1], stride=strides[2])
        self.block3 = Nblock(N, input_features=filters[1], output_features=filters[2], stride=strides[3])

        self.batch_norm = nn.BatchNorm2d(filters[-1])
        self.activation = nn.ReLU(inplace=True)
        # An 8x8 window reduces an 8x8 feature map to 1x1 (e.g. a 32x32 input
        # with strides [1, 1, 2, 2]).
        self.avg_pool = nn.AvgPool2d(kernel_size=8)
        self.fc = nn.Linear(filters[-1], n_classes)

        # Re-initialise all layers; nn.init works under no_grad, so there is
        # no need to reach into `.data`.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=(2.0 / fan_out) ** 0.5)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: 4-D input tensor with ``input_features`` channels.

        Returns:
            Tuple ``(logits, attention1, attention2, attention3)``.

        Raises:
            RuntimeError: from the linear layer if the pooled feature map is
                not 1x1 (input resolution / strides do not match the 8x8
                pooling window).
        """
        x = self.conv1(x)
        attention1 = self.block1(x)
        attention2 = self.block2(attention1)
        attention3 = self.block3(attention2)

        # BN returns a new tensor, so the in-place ReLU leaves attention3 intact.
        out = self.activation(self.batch_norm(attention3))
        out = self.avg_pool(out)
        # Keep the batch dimension fixed. `view(-1, C)` would silently fold
        # left-over spatial positions into the batch dimension instead of
        # failing when the pooled map is larger than 1x1.
        out = torch.flatten(out, 1)

        return self.fc(out), attention1, attention2, attention3


if __name__ == '__main__':

    def wrn_28_10():
        """Return a newly constructed WRN-28-10 (10 classes, 3 input channels)."""
        # change d and k to build another variant (e.g. d=40, k=2 for WRN-40-2)
        return WideResNet(d=28, k=10, n_classes=10, input_features=3, output_features=16, strides=[1, 1, 2, 2])

    # `net` is the model *factory*: it is bound without parentheses so that
    # `net(*hyperparams)` below constructs a model instead of running forward().
    net = wrn_28_10

    # verify that an output is produced
    model = net()
    sample_input = torch.ones(size=(1, 3, 32, 32), requires_grad=False)
    model(sample_input)  # forward without crash.

    hyperparams = ()
    initializations = [net(*hyperparams).state_dict()]

    # NOTE: when `net` was bound to a WideResNet *instance* instead of a
    # factory, `net(*hyperparams)` with empty hyperparams called forward()
    # without its input and produced:
    #   File "Python311\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    #      return forward_call(*args, **kwargs)
    #   TypeError: WideResNet.forward() missing 1 required positional argument: 'x'
    #
    # Observed with Python 3.11.7 + Windows 10 + torch 2.2.0.dev20231122+cu121 (and torch-2.3.0.dev20240109+cu121).
    # The same calling code worked for Resnet-18 + Resnet-34 + Resnet-50
    # (presumably because those were passed as factories - not verified here).
    # 3 different public Wide-Resnet-28-10 repos from github.com all failed in the same way.

    # Summarize model
    # summary(model, input_size=(3, 32, 32))
#EOF.

System details can be seen at the end of the sample code above. I have already spent hours on this and would be very grateful for any help. Thanks.

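OK, solved. “net” needs to be a function that builds the model, passed without calling it (i.e. without the trailing “()”), rather than an already-constructed model instance.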

def WRN28():
    """Build and return a new WideResNet-28-10 every time it is called."""
    return WideResNet(d=28, k=10, n_classes=10, input_features=3, output_features=16, strides=[1, 1, 2, 2])


# Bind the factory itself (no parentheses) so that the caller can construct
# the model later via net(*hyperparams).
net = WRN28

Could you please add an assert( callable(model_class) ) somewhere? torch\nn\modules\module.py tries to do something with a “function” but gets the result of “function()” instead and does not notice it, so it raises an error that most people will not understand. Please.