I need help: I don’t understand where I made a mistake that causes the following error when loading the network below.
RuntimeError: Error(s) in loading state_dict for EncoderCNN: While copying the parameter named "embed.bias", whose dimensions in the model are torch.Size([256]) and whose dimensions in the checkpoint are torch.Size([5]). While copying the parameter named "embed.weight", whose dimensions in the model are torch.Size([256, 2048]) and whose dimensions in the checkpoint are torch.Size([5, 2048]).
My model is:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
class EncoderCNN(nn.Module):
    """Encode a batch of images into feature vectors of size ``embed_size``.

    A pretrained ResNet-50 with its final child module (the ``fc``
    classification layer) removed produces pooled features. A trainable
    linear layer ``embed`` then projects them to ``embed_size``. Gradients
    are not propagated into the ResNet trunk (see ``forward``).

    Args:
        embed_size: Output feature size. It also fixes the parameter shapes
            of ``embed``: ``embed.weight`` is ``(embed_size,
            resnet.fc.in_features)`` (2048 for ResNet-50) and ``embed.bias``
            is ``(embed_size,)``. It must therefore equal the value used when
            a checkpoint was saved. Otherwise ``load_state_dict`` raises a
            size-mismatch ``RuntimeError``. For example, a checkpoint saved
            with ``embed_size=5`` cannot be loaded into a model built with
            ``embed_size=256``.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # NOTE(review): newer torchvision versions deprecate
        # ``pretrained=True`` in favour of ``weights=...``. Kept as-is
        # because the installed torchvision version is not visible here.
        resnet = models.resnet50(pretrained=True)
        # Keep every child except the last one (the ``fc`` classifier), so
        # the trunk ends at ResNet's global pooling layer.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        """Return the embedded features of shape ``(batch, embed_size)``.

        Args:
            images: Image batch accepted by the ResNet trunk. Presumably a
                ``(batch, 3, H, W)`` float tensor; TODO confirm against the
                data loader.
        """
        features = self.resnet(images)
        # Cut the autograd graph so only ``embed`` is trained through this
        # module. ``detach()`` is the supported replacement for the
        # deprecated ``Variable(features.data)`` idiom and has the same effect.
        features = features.detach()
        # Collapse everything after the batch dimension into one vector per
        # image.
        features = torch.flatten(features, 1)
        return self.embed(features)
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature vector is fed to the LSTM as the first time step,
    followed by the embedded caption tokens.

    Args:
        embed_size: Size of the word embeddings and of the LSTM input. It
            must match the feature size of the vectors passed as
            ``features``/``inputs``.
        hidden_size: Number of hidden units in each LSTM layer.
        vocab_size: Number of rows in the embedding table and number of
            output logits per step.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=2):
        super(DecoderRNN, self).__init__()
        # Submodules are created in the same order as before; this keeps
        # parameter initialisation and state_dict key order unchanged.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        # NOTE(review): this dropout layer is constructed but never applied in
        # forward() or sample(); confirm whether that is intentional.
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Score every vocabulary token at every time step.

        Args:
            features: Image features of shape ``(batch, embed_size)``.
            captions: Token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``. The first step
            is driven by the image features. The remaining steps are driven
            by ``captions[:, :-1]``.
        """
        # The last caption token is never fed as an input. Presumably it is
        # the end token, which has no successor to predict.
        word_vectors = self.embed(captions[:, :-1])
        sequence = torch.cat([features.unsqueeze(1), word_vectors], dim=1)
        lstm_out, _ = self.lstm(sequence)
        return self.linear(lstm_out)

    def sample(self, inputs, states=None, max_len=20):
        """Greedily decode ``max_len`` tokens starting from ``inputs``.

        Args:
            inputs: First LSTM input, expected as ``(batch, 1, embed_size)``
                (for example the unsqueezed image features).
            states: Optional initial ``(h, c)`` LSTM state.
            max_len: Number of tokens to generate.

        Returns:
            A list of ``max_len`` 0-dim tensors holding the arg-max token id
            at each step, for the first batch element only.
        """
        token_ids = []
        step_input = inputs
        for _ in range(max_len):
            lstm_out, states = self.lstm(step_input, states)
            logits = self.linear(lstm_out.squeeze(1))
            _, predicted = logits.max(dim=1)
            token_ids.append(predicted.data[0])
            # Feed the chosen token back in as the next single-step input.
            step_input = self.embed(predicted).unsqueeze(1)
        return token_ids