[PyTorch Android] I get the error “Method 'forward' is not defined”, but I defined the 'forward' function in my model

I implemented a ResNet-18 variant of the network myself and defined the 'forward' function in it. The model trains normally on my computer, but when I save it following the method from the PyTorch Android tutorial, I get the error “Method 'forward' is not defined.”. I would like to know what the problem is.

Here is my code:
```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model = multi_resnet18_kd(first_channel=1, num_classes=50)
model.eval()
example = torch.rand(1, 1, 200, 68)
traced_script_module = torch.jit.trace(model, example)
print(traced_script_module)
optimized_traced_model = optimize_for_mobile(traced_script_module)  # the error is raised here
optimized_traced_model.save("model.pt")
```
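As a quick sanity check, printing the traced module's generated TorchScript source (via the standard `code` attribute of `ScriptModule`) can show whether tracing actually produced a `forward` method before the optimization step:

```python
# Pretty-print the TorchScript source that tracing generated for forward;
# if the method were genuinely missing, this attribute access would fail
print(traced_script_module.code)
```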

Could you post the model definition so that we could reproduce this issue, please?

PS: you can post code snippets by wrapping them in three backticks ```, which makes debugging easier :wink:

```python
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)

def conv1x1(in_planes, planes, stride=1):
    return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)

def branchBottleNeck(channel_in, channel_out, kernel_size):
    middle_channel = channel_out//4
    return nn.Sequential(
        nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        
        nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
        nn.BatchNorm2d(middle_channel),
        nn.ReLU(),
        
        nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(channel_out),
        nn.ReLU(),
    )

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        residual = x

        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)

        output = self.conv2(output)
        output = self.bn2(output)

        if self.downsample is not None:
            residual = self.downsample(x)
        
        output += residual
        output = self.relu(output)
        return output

class BottleneckBlock(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BottleneckBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = conv1x1(planes, planes*self.expansion)
        self.bn3 = nn.BatchNorm2d(planes*self.expansion)

        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        residual = x

        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)

        output = self.conv2(output)
        output = self.bn2(output)
        output = self.relu(output)

        output = self.conv3(output)
        output = self.bn3(output)
        
        if self.downsample is not None:
            residual = self.downsample(x)

        output += residual
        output = self.relu(output)

        return output

class Multi_ResNet(nn.Module):

    def __init__(self, block, layers, first_channel=3, num_classes=1000):
        super(Multi_ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(first_channel, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
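
        # Auxiliary branch heads (bottleneck + classifier per stage); note that
        # they are defined below but not used in the forward pass.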

        self.downsample1_1 = nn.Sequential(
                            conv1x1(64 * block.expansion, 512 * block.expansion, stride=8),
                            nn.BatchNorm2d(512 * block.expansion),
        )
        self.bottleneck1_1 = branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
        self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
        self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)


        self.downsample2_1 = nn.Sequential(
                            conv1x1(128 * block.expansion, 512 * block.expansion, stride=4),
                            nn.BatchNorm2d(512 * block.expansion),
        )
        self.bottleneck2_1 = branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
        self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
        self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)


        self.downsample3_1 = nn.Sequential(
                            conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
                            nn.BatchNorm2d(512 * block.expansion),
        )
        self.bottleneck3_1 = branchBottleNeck(256 * block.expansion, 512 * block.expansion, kernel_size=2)
        self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
        self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d): 
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def _make_layer(self, block, planes, layers, stride=1):

        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layer = []
        layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, layers):
            layer.append(block(self.inplanes, planes))
        
        return nn.Sequential(*layer)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)

        x = self.layer2(x)

        x = self.layer3(x)

        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def multi_resnet18_kd(first_channel=3, num_classes=1000):
    """ResNet-18 configuration: two BasicBlocks per stage."""
    return Multi_ResNet(BasicBlock, [2, 2, 2, 2], first_channel=first_channel, num_classes=num_classes)
```

Thanks for the code snippet.
I cannot reproduce the issue using a source build with:

```python
torch.__version__
'1.9.0a0+git6d45d7a'

torchvision.__version__
'0.11.0a0+882e11d'
```

and the optimization step doesn't raise an error, so you might want to update to the latest nightly binary (or build from source). If you are using an older stable release, you might also consider updating to the latest stable one (I haven't checked 1.8.1).
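
For completeness, here is a minimal sketch of the export flow to retry after updating, assuming the `multi_resnet18_kd` definition from above is in scope (the version print and the try/except fallback are illustrative additions, not part of the original snippet):

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

print(torch.__version__)  # confirm you are on a recent build before retrying

model = multi_resnet18_kd(first_channel=1, num_classes=50)
model.eval()  # switch to inference mode before tracing/exporting

example = torch.rand(1, 1, 200, 68)
traced = torch.jit.trace(model, example)

try:
    optimized = optimize_for_mobile(traced)
except RuntimeError as e:
    # On the affected build this is where "Method 'forward' is not defined"
    # was raised; fall back to the unoptimized traced module if it recurs
    print(f"optimize_for_mobile failed: {e}")
    optimized = traced

optimized.save("model.pt")
```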

Thank you for your answer; the problem has been solved.