I'm adapting an Inception v3 model as the encoder for image captioning. After one epoch of training, I get the ValueError below. How can I fix this?
Traceback (most recent call last):
  File "train.py", line 107, in <module>
    main(args)
  File "train.py", line 61, in main
    features = encoder(images)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mango/바탕화면/lab/pytorch-tutorial/tutorials/03-advanced/image_captioning/model.py", line 20, in forward
    embed = self.inception(images)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torchvision/models/inception.py", line 199, in forward
    x, aux = self._forward(x)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torchvision/models/inception.py", line 169, in _forward
    aux = self.AuxLogits(x)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torchvision/models/inception.py", line 419, in forward
    x = self.conv1(x)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torchvision/models/inception.py", line 440, in forward
    x = self.bn(x)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 136, in forward
    self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/functional.py", line 2054, in batch_norm
    _verify_batch_size(input.size())
  File "/home/mango/anaconda3/envs/img/lib/python3.7/site-packages/torch/nn/functional.py", line 2037, in _verify_batch_size
    raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 768, 1, 1])
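If it helps to narrow this down, my guess is that the last batch of the epoch contains only a single image, so the BatchNorm inside the auxiliary classifier sees one sample with a 1x1 feature map and has only one value per channel to compute statistics from. This minimal sketch (unrelated to my data loader, just my assumption spelled out) reproduces the exact same error:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(768)
bn.train()                      # training mode: batch statistics are computed from the input
x = torch.randn(1, 768, 1, 1)   # batch of 1, spatial size 1x1 -> one value per channel
bn(x)                           # ValueError: Expected more than 1 value per channel when training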
My model.py (encoder and decoder):

import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


class Inception(nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained Inception v3 and set up the embedding layers."""
        super(Inception, self).__init__()
        self.inception = models.inception_v3(pretrained=True)
        #in_features = self.inception.fc.in_features
        #self.linear = nn.Linear(in_features, embed_size)
        self.linear = nn.Linear(self.inception.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        #self.inception.fc = self.linear

    def forward(self, images):
        """Extract feature vectors from the input images."""
        embed = self.inception(images)
        return embed
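One thing I noticed while debugging, in case it matters: in training mode, inception_v3 with its default aux_logits=True returns a namedtuple with logits and aux_logits fields rather than a plain tensor, which is why the decoder below reads features.logits. A quick check of what the encoder hands to the decoder, assuming a toy batch of two 299x299 images:

import torch
import torchvision.models as models

model = models.inception_v3(pretrained=True)   # aux_logits=True by default
model.train()
out = model(torch.randn(2, 3, 299, 299))       # Inception v3 expects 299x299 inputs
print(out.logits.shape)       # torch.Size([2, 1000]) -- still the ImageNet fc head, since I did not replace it
print(out.aux_logits.shape)   # torch.Size([2, 1000])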
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generate captions."""
        embeddings = self.embed(captions)
        #print(embeddings.size())
        #print(features.logits.unsqueeze(1).size())
        embeddings = torch.cat((features.logits.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)   # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))     # outputs: (batch_size, vocab_size)
            _, predicted = outputs.max(1)                 # predicted: (batch_size)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)                # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                  # inputs: (batch_size, 1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)         # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids
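For completeness, the training step around line 61 of train.py is essentially the standard loop from the image-captioning tutorial this is adapted from. Roughly (paraphrased from memory, not the exact file; data_loader, device, encoder, decoder, and optimizer are my usual setup rather than code shown here):

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

criterion = nn.CrossEntropyLoss()

for images, captions, lengths in data_loader:
    images, captions = images.to(device), captions.to(device)
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

    features = encoder(images)                      # line 61 in the traceback
    outputs = decoder(features, captions, lengths)  # packed scores over the vocabulary
    loss = criterion(outputs, targets)

    decoder.zero_grad()
    encoder.zero_grad()
    loss.backward()
    optimizer.step()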