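The error is thrown because you are wrapping all child modules in an nn.Sequential container, which drops the flatten operation that ResNet performs in its forward method. That flatten is a function call inside forward, not a child module, so model.children() never returns it.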
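A minimal reproduction of the problem, assuming a hypothetical toy model (TinyNet) that flattens inside forward the same way torchvision's ResNet does. The full model works, but the nn.Sequential rebuilt from its children passes a 4D tensor straight into the linear layer and fails:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical toy model: like ResNet, it flattens inside forward()
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.conv(x))  # [N, 8, 1, 1]
        x = torch.flatten(x, 1)      # [N, 8]; not a child module
        return self.fc(x)

model = TinyNet()
x = torch.randn(2, 3, 16, 16)
full_out = model(x)
print(full_out.shape)  # torch.Size([2, 10])

# Rebuilding from children() silently drops the flatten call
seq = nn.Sequential(*model.children())
try:
    seq(x)  # nn.Linear(8, 10) receives [2, 8, 1, 1]; last dim is 1, not 8
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```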
You could define a custom Flatten
module and add it right before the last linear layer:
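```python
import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18()  # assumption: model is a torchvision ResNet

class Flatten(nn.Module):
    # Mirrors the flatten in ResNet.forward: [N, C, 1, 1] -> [N, C]
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return x

# conv1, bn1, relu
modules = list(model.children())[:3]
resnet_1st = nn.Sequential(*modules)
# maxpool, layer1-4, avgpool, then Flatten right before the final fc layer
modules = list(model.children())[3:-1]
resnet_2nd = nn.Sequential(*[*modules, Flatten(), list(model.children())[-1]])

x = torch.randn(1, 3, 224, 224)
out_1st = resnet_1st(x)
print(out_1st.shape)  # torch.Size([1, 64, 112, 112])
out_2nd = resnet_2nd(out_1st)
print(out_2nd.shape)  # torch.Size([1, 1000])
```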