Aim: extract the intermediate features of a pretrained ResNet-50 (say, from block_num = 4) and then feed those same features into the remaining blocks. (In between I need to apply some manipulation to the features, but I have left that out here.) The shape of the block features is exactly what the next layer expects, yet passing them through the rest of the network fails, and I have not been able to solve this. Kindly help.
import torch
import torch.nn as nn
from torchvision import models
model = models.resnet50(pretrained=True)
model.eval()
inputs = torch.randn(2, 3, 224, 224)
block_num = 4
# resnet50 children: conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc
feature_extractor = nn.Sequential(*list(model.children())[:block_num + 1])   # conv1 ... layer1
rest_of_network = nn.Sequential(*list(model.children())[block_num + 1:])     # layer2 ... fc
block_features = feature_extractor(inputs)
print(block_features.shape)
perturbed_block_features = rest_of_network(block_features)  # <- raises the error below
Terminal output:
torch.Size([2, 256, 56, 56])
Traceback (most recent call last):
  File "trial.py", line 22, in <module>
    perturbed_block_features = rest_of_network(block_features)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/prafful/scratch/conda_envs/MRG/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x1 and 2048x1000)
When run on CUDA, I get this error instead:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)
Please help me solve this issue. I cannot understand why this simple code does not work.
Best Regards,
Prafful