Bifurcation of a pretrained model and error in blockwise forward pass

Aim: extract the intermediate features (say, from block_num=4) of a pretrained resnet50, then feed those extracted features to the remaining blocks. (In between I need to manipulate the features, but I have left that step out here.) The block features already have the shape the next layer expects, yet the forward pass through the remaining blocks fails. I am unable to solve this issue. Kindly help.

import torch
import torch.nn as nn

from torchvision import models

model = models.resnet50(pretrained=True)


inputs = torch.randn(2, 3, 224, 224)

block_num = 4

feature_extractor = nn.Sequential(*list(model.children())[:block_num + 1])
rest_of_network = nn.Sequential(*list(model.children())[block_num+1 :])
block_features = feature_extractor(inputs)
print(block_features.shape)  # torch.Size([2, 256, 56, 56])

perturbed_block_features = rest_of_network(block_features)  # -> RuntimeError (see traceback)

terminal output:

torch.Size([2, 256, 56, 56])
Traceback (most recent call last):
  File "", line 22, in
    perturbed_block_features = rest_of_network(block_features)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/", line 139, in forward
    input = module(input)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x1 and 2048x1000)

When run on CUDA, I get this error instead:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)

Please help me solve this issue. I do not understand why this simple code fails.

Best Regards,


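Wrapping submodules into an nn.Sequential container will often fail, as you would be missing all functional API calls made in the original forward method. For resnet50 you would miss the torch.flatten(x, 1) call between avgpool and fc, which is most likely causing the issue: fc receives the [2, 2048, 1, 1] output of avgpool instead of a [2, 2048] tensor, and since F.linear multiplies over the last dimension, it sees 2 * 2048 * 1 = 4096 rows of size 1, hence the 4096x1 vs. 2048x1000 mismatch.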
Either add an nn.Flatten module manually to your nn.Sequential container or derive a custom model and override the forward method with your custom approach.

Thank you, it works fine.