ResNet18: get the output before the last FC layer (nn.Sequential used as FC)

I need some help. I want to get the normal output of the ResNet18 that I modified, but also the output of nn.Linear(512, 32). I am able to get the activation before nn.Linear(512, 32), but not the one coming out of it, i.e. the one just before nn.Linear(32, 5). I need it so that I can feed these extracted features to an LSTM.

import torch
from torchvision import models
import torch.nn as nn


class Resnet18(torch.nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        resnet18_pretrained = models.resnet18(pretrained=True)

        self.model = resnet18_pretrained
        self.model.fc =  nn.Sequential(nn.Linear(512, 32),nn.Linear(32, 5))
        self.couches_before_fc = list(self.model.children())[:-1]
        self.resnet_before_fc = nn.Sequential(*self.couches_before_fc)
        self.resnet_before_fc.fc = self.model.fc[0]
        #self.resnet_before_fc = nn.Sequential(*self.couches_before_fc,self.model.fc[0]) #also try this way

    def forward(self, x):

        before_last_fc = self.resnet_before_fc(x)

        x=self.model(x)

        return x, before_last_fc

This strategy seems logical to me, but I get the following error:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)

It is raised at the line before_last_fc = self.resnet_before_fc(x).

Using a hook, I can get the value of the output just before nn.Linear(32, 5), but I cannot get it as a tensor and return it at the end of forward…

self.model.fc[0].register_forward_hook(lambda m, input, output: print(output))

Also, this additional line makes the code crash when I try to save the model:
torch.save(model, model_out_path)
AttributeError: Can't pickle local object 'Resnet18.forward.<locals>.<lambda>'
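On the hook approach: a forward hook can hand back the tensor if it stores the output instead of printing it. The following is only a minimal sketch under assumptions: HeadWithHook, _store_fc0_output, and captured are hypothetical names, and the module stands in for just the fc part of the model (it takes 512-d features, not images). The hook is registered once in __init__ as a method rather than as a lambda inside forward, and only the state_dict is saved, which does not contain hooks.

```python
import io
import torch
import torch.nn as nn


class HeadWithHook(nn.Module):  # hypothetical stand-in for the fc part of the model
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(512, 32), nn.Linear(32, 5))
        self.captured = {}
        # Register once here, not in forward, so hooks do not pile up on every call.
        self.fc[0].register_forward_hook(self._store_fc0_output)

    def _store_fc0_output(self, module, inputs, output):
        # Keep the tensor itself instead of printing it.
        self.captured["fc0"] = output

    def forward(self, x):
        out = self.fc(x)
        return out, self.captured["fc0"]


model = HeadWithHook()
out, before_last_fc = model(torch.randn(2, 512))
print(out.shape, before_last_fc.shape)

# Saving only the state_dict avoids pickling the hook altogether.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
```

The captured tensor is still part of the autograd graph, so it can be passed on to an LSTM and trained end to end.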

Thanks a lot!

If I execute your code, I get a shape mismatch error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1 and 512x32)

since you are not flattening the activation.
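The mismatch can be reproduced without the full network. This is a small sketch, assuming a random tensor as a stand-in for the feature map that reaches ResNet18's average pooling layer (512 channels, 7x7 for a 224x224 input); feature_map, fc0, and raised are illustrative names.

```python
import torch
import torch.nn as nn

# Stand-in for the feature map entering the average pooling layer
# (batch of 2, 512 channels, 7x7 spatial map).
feature_map = torch.randn(2, 512, 7, 7)

pooled = nn.AdaptiveAvgPool2d((1, 1))(feature_map)
print(pooled.shape)  # torch.Size([2, 512, 1, 1])

fc0 = nn.Linear(512, 32)

# Without flattening, nn.Linear sees a last dimension of size 1, not 512.
try:
    fc0(pooled)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)

# Flattening everything after the batch dimension gives the expected [2, 512] input.
flat = torch.flatten(pooled, 1)
features = fc0(flat)
print(features.shape)  # torch.Size([2, 32])
```

The 1024x1 in the error message is the [2, 512, 1, 1] tensor with its leading dimensions collapsed, which is why the linear layer cannot multiply it with a 512x32 weight.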
This should work:

class Resnet18(torch.nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        resnet18_pretrained = models.resnet18(pretrained=True)

        self.model = resnet18_pretrained
        self.model.fc = nn.Sequential(nn.Linear(512, 32), nn.Linear(32, 5))
        # All child modules except fc; functional calls from forward are not included
        self.couches_before_fc = list(self.model.children())[:-1]
        self.resnet_before_fc = nn.Sequential(*self.couches_before_fc)
        # Flatten [N, 512, 1, 1] to [N, 512] before the first linear layer
        self.resnet_before_fc.fc = nn.Sequential(
            nn.Flatten(),
            self.model.fc[0]
        )

    def forward(self, x):
        # Output of nn.Linear(512, 32), shape [N, 32]
        before_last_fc = self.resnet_before_fc(x)

        # Regular output of the full model, shape [N, 5]
        x = self.model(x)

        return x, before_last_fc


model = Resnet18()
x = torch.randn(2, 3, 224, 224)
out, before_last_fc = model(x)

The error is raised because you are wrapping all child modules into an nn.Sequential container, which drops every functional API call made in the original forward method, including the torch.flatten call that runs before fc. Make sure to verify that the nn.Sequential container really calls all the needed operations and modules.

Thanks a lot, this solution works!