I am working on an image captioning model with an Encoder-Decoder architecture, where the Encoder is a pre-trained CNN module (inception_v3) and the Decoder consists of an embedding layer followed by an LSTM layer.
I am training on the 2017 COCO captioning dataset, and while training my model I observed some strange behavior.
When I built my dataset, I added a <PAD> symbol to the vocabulary; it is placed after the <EOS> symbol to pad all caption tensors to the same length (see the short sketch after the symbol definitions below).
After a small number of iterations, my model converges to captioning every image as <SOS><PAD><PAD>…<PAD>.
I was wondering what could cause such an issue.
<SOS> = start-of-sentence symbol
<EOS> = end-of-sentence symbol
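For example, after numericalization a padded caption tensor looks roughly like this (a minimal sketch; the token indices and the max length are made up for illustration):

# illustrative only: assume stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, ...}
max_length = 10                             # hypothetical padded length
caption = [1, 23, 57, 102, 2]               # <SOS> a dog runs <EOS>
padded = caption + [0] * (max_length - len(caption))
# padded == [1, 23, 57, 102, 2, 0, 0, 0, 0, 0]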
I am attaching the model code along with the training code.
import torch
import torch.nn as nn
import torchvision.models as models
device = "cuda" if torch.cuda.is_available() else "cpu"
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        # pre-trained Inception v3 backbone; its classifier head is replaced
        # by a linear layer that projects the features to the embedding size
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        # images: (B, 3, 299, 299) -> features: (B, embed_size)
        features = self.inception(images)
        output = self.dropout(self.relu(features))
        return output
class DecoderRNN(nn.Module):
    """
    Input is the image feature vector from the CNN encoder; output is a caption.
    TODO: Check how to implement a transformer for better results
    """
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm_cell = nn.LSTMCell(embed_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
    def forward(self, features, captions):
        batch_size = features.size(0)
        # init the hidden and cell states to zeros
        hidden_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        cell_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        # placeholder for the per-step vocabulary scores
        outputs = torch.empty((batch_size, captions.size(1), self.vocab_size)).to(device)
        # embed the captions: tensor of shape (B, LEN, EMBED_SIZE),
        # where LEN is the padded caption length (longest caption + 2 for <SOS>/<EOS>)
        captions_embed = self.embed(captions)
        # feed the caption word by word
        for t in range(captions.size(1)):
            if t == 0:
                # for the first time step the input is the image feature vector
                hidden_state, cell_state = self.lstm_cell(features, (hidden_state, cell_state))
            else:
                # for later time steps, use teacher forcing: feed the ground-truth token
                hidden_state, cell_state = self.lstm_cell(captions_embed[:, t, :], (hidden_state, cell_state))
            # project the hidden state to vocabulary scores
            out = self.fc_out(self.dropout(hidden_state))
            # build the output tensor
            outputs[:, t, :] = out
        return outputs
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, train_CNN=False):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size, train_CNN).to(device)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs
    def caption_images(self, features, vocab, max_len=50):
        # Inference: given a single image, greedily generate a caption.
        # input shape: (1, 3, H, W) - a batch containing one image
        # output: list of caption words
        self.eval()
        with torch.no_grad():
            image_pred = self.encoderCNN(features)
            batch_size = features.size(0)
            # init the hidden and cell states to zeros
            hidden_state = torch.zeros((batch_size, self.decoderRNN.hidden_size)).to(device)
            cell_state = torch.zeros((batch_size, self.decoderRNN.hidden_size)).to(device)
            captions_embed = None  # embedding of the previously generated word
            captions = []
            for t in range(max_len):
                if t == 0:
                    # for the first time step the input is the image feature vector
                    hidden_state, cell_state = self.decoderRNN.lstm_cell(image_pred, (hidden_state, cell_state))
                else:
                    # for later time steps, feed back the previously generated word
                    hidden_state, cell_state = self.decoderRNN.lstm_cell(captions_embed, (hidden_state, cell_state))
                # project the hidden state to vocabulary scores
                out = self.decoderRNN.fc_out(self.decoderRNN.dropout(hidden_state))
                word_idx = torch.argmax(out).item()  # greedy choice (assumes batch size 1)
                captions.append(word_idx)
                if word_idx == vocab.stoi["<EOS>"]:
                    break
                captions_embed = self.decoderRNN.embed(torch.argmax(out)).unsqueeze(0)
            print(captions)
        # convert the vocab indices to words, dropping <PAD>, and return the sentence
        self.train()
        return [vocab.itos[idx] for idx in captions if idx != vocab.stoi["<PAD>"]]
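For reference, here is a quick shape check of the wiring with toy sizes (the numbers below are purely illustrative, and the first run downloads the pretrained Inception weights):

# toy shape check: vocab_size=100 and the embed/hidden sizes are hypothetical
toy_model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=100)
dummy_images = torch.randn(2, 3, 299, 299).to(device)       # inception_v3 expects 299x299 inputs
dummy_captions = torch.randint(0, 100, (2, 12)).to(device)  # (batch, seq_len) token indices
out = toy_model(dummy_images, dummy_captions)
print(out.shape)  # torch.Size([2, 12, 100]) = (batch, seq_len, vocab_size)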
import torch.optim as optim
from tqdm import tqdm
def train(max_epochs, model):
    # Hyperparameters
    learning_rate = 3e-4
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # init model
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    # start epochs
    for epoch in range(max_epochs):
        for idx, (img, captions) in tqdm(
            enumerate(data_loader), total=len(data_loader), leave=False
        ):
            img = img.to(device)
            captions = captions.to(device).long()
            output = model(img, captions)
            # flatten (B, LEN, VOCAB) -> (B*LEN, VOCAB) to match the (B*LEN,) targets
            loss = criterion(output.reshape(-1, output.shape[2]), captions.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # every 10 iterations, print the loss and show a sample prediction
            if idx % 10 == 0 and idx > 0:
                dataiter = iter(data_loader)
                img_show, cap = next(dataiter)
                print(f"\nLoss {loss.item():.5f}\n")
                demo_cap = model.caption_images(img_show[0:1].to(device), vocab=dataset.vocab, max_len=30)
                demo_cap = ' '.join(demo_cap)
                print("Predicted")
                cap = cap[0]
                print(cap.long())
                print("Original")
                show_image(img_show[0], title=demo_cap)
                demo_cap = ' '.join([dataset.vocab.itos[idx2.item()] for idx2 in cap if idx2.item() != dataset.vocab.stoi["<PAD>"]])
                show_image(img_show[0], title=demo_cap, transform=False)
                #input("Continue?")
    return model
embed_size = 1024
hidden_size = 512
vocab_size = len(dataset.vocab)
model = CNNtoRNN(embed_size, hidden_size, vocab_size, train_CNN=False)
trained_model = train(3, model)
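For completeness, data_loader, dataset, and show_image are defined in my dataset code (not shown). The vocab object the code assumes exposes stoi/itos mappings roughly like this (a simplified sketch, not my actual class):

# simplified sketch of the vocabulary interface assumed above
class Vocab:
    def __init__(self):
        # index 0 is <PAD>; the real vocab is built from the COCO captions
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}  # index -> token
        self.stoi = {tok: idx for idx, tok in self.itos.items()}      # token -> index

    def __len__(self):
        return len(self.itos)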