Break resnet into two parts

The error is raised because you are wrapping all child modules in an nn.Sequential container, which skips the flatten operation applied in resnet's forward method, so the 4D avgpool output is passed directly to the final linear layer.
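A minimal sketch reproducing the failure (assuming model is a torchvision resnet18; the exact shape-mismatch message depends on your version):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
naive = nn.Sequential(*list(model.children()))  # misses the flatten from forward
naive(torch.randn(1, 3, 224, 224))  # RuntimeError: 4D avgpool output hits the fc layer unflattened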
You could define a custom Flatten module and add it right before the last linear layer:

class Flatten(nn.Module):
    def forward(self, x):
        # flatten all dimensions after the batch dim: (N, C, 1, 1) -> (N, C)
        return x.view(x.size(0), -1)

# first part: the stem (conv1, bn1, relu)
modules = list(model.children())[:3]
resnet_1st = nn.Sequential(*modules)

# second part: maxpool, layer1-4, avgpool, then Flatten, then the fc layer
modules = list(model.children())[3:-1]
resnet_2nd = nn.Sequential(*modules, Flatten(), list(model.children())[-1])

x = torch.randn(1, 3, 224, 224)
out_1st = resnet_1st(x)
print(out_1st.shape)  # torch.Size([1, 64, 112, 112]) after the stride-2 conv1
out_2nd = resnet_2nd(out_1st)
print(out_2nd.shape)  # torch.Size([1, 1000]) for the ImageNet fc head
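As a side note, newer PyTorch versions (1.2+, if I recall correctly) ship a built-in nn.Flatten, whose default start_dim=1 matches x.view(x.size(0), -1) here, so the custom module isn't strictly necessary:

resnet_2nd = nn.Sequential(
    *list(model.children())[3:-1],
    nn.Flatten(),
    list(model.children())[-1],
)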