LSTM only learning with batch_size=1

Hi,

I am currently trying to set up an LSTM for image captioning. I use the image features as the initial hidden state and then generate an output sequence of words. As a first sanity check, I tried to overfit on 50 images. This works very well with batch_size=1. For larger batch sizes, however, it doesn't really converge to anything good and ends up predicting the same word everywhere, usually the most frequent first word in the training set. This is my model:

class CaptioningLSTM(nn.Module):
  def __init__(self, embed_size, hidden_size, voc_dim, feat_dim, null, end, start, n_l=1):
    # Size of word embeddings, hidden state size, vocabulary size, image feature size,
    # indices of the null/end/start tokens, number of layers
    super(CaptioningLSTM, self).__init__()
    self.embed_size = embed_size
    self.hidden_size = hidden_size
    self.voc_dim = voc_dim
    self.feat_dim = feat_dim
    self.n_l = n_l
    self.word_embed = nn.Embedding(self.voc_dim, self.embed_size)
    # self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.hidden_size,
    #                     num_layers=self.n_l, batch_first=True)
    self.lstm = nn.LSTMCell(input_size=self.embed_size, hidden_size=self.hidden_size)
    self.null = null
    self.end = end
    self.start = start
    self.dropout = nn.Dropout(0.2)
    self.lin_out = nn.Linear(self.hidden_size, self.voc_dim)
    self.lin_image = nn.Linear(self.feat_dim, self.hidden_size)
    self.word_embed.weight.data.uniform_(-0.1, 0.1)
    self.lin_out.bias.data.fill_(0)
    self.lin_out.weight.data.uniform_(-0.1, 0.1)
    self.lin_image.bias.data.fill_(0)
    self.lin_image.weight.data.uniform_(-0.1, 0.1)
    self.relu = nn.ReLU()
    
  
  def initialize_h(self, batch_size):
    return( torch.zeros(batch_size, self.hidden_size).to(device),
           torch.zeros(batch_size, self.hidden_size).to(device))
  
  def forward(self, features, captions):
    # Teacher forcing: feed every caption token except the last one as input
    captions = captions[:, :-1]
    self.batch_size = features.shape[0]
    # Cell state starts at zero; the hidden state is replaced by the projected image features
    h, c = self.initialize_h(self.batch_size)
    h = self.relu(self.lin_image(features.float()))
    # Pre-fill the score buffer with the null index
    outputs = (torch.ones((self.batch_size, captions.size(1), self.voc_dim)) * self.null).to(device)
    for t in range(captions.size(1)):
        h, c = self.lstm(self.word_embed(captions[:, t]), (h, c))
        out = self.lin_out(h)
        outputs[:, t, :] = out
    return outputs
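
For reproducibility, here is a minimal sketch of how the model is constructed and called. Only the vocabulary size and caption length match the shapes printed further down; the remaining sizes and token indices are placeholders for illustration:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder hyperparameters: only voc_dim=1004 and the caption length of 17
# correspond to the shapes printed below; the rest is made up for illustration
model = CaptioningLSTM(embed_size=256, hidden_size=512, voc_dim=1004,
                       feat_dim=2048, null=0, end=2, start=1).to(device)

features = torch.randn(32, 2048).to(device)             # (batch_size, feat_dim)
captions = torch.randint(0, 1004, (32, 17)).to(device)  # (batch_size, seq_length)

scores = model(features, captions)
print(scores.shape)                                     # torch.Size([32, 16, 1004])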

And this is the training loop:

for epoch in range(num_epochs):
    caps, feats, u_s = get_all_batches(data, batch_size)
    for it_batch, (captions, features, urls) in enumerate(zip(caps, feats, u_s)):
      # Compute loss and gradient
      scores = model(features, captions)
      loss = loss_fn(scores.transpose(1,2), captions[:,1:])
      print(scores.transpose(1, 2).shape, captions[:, 1:].shape)  # sanity check of the shapes going into the loss
      optim.zero_grad()
      loss.backward()
      optim.step()

The captions are of shape (batch_size, seq_length) and the network's outputs are of shape (batch_size, seq_length, vocabulary_dim), hence the transpose(), so that the inputs to the loss are
torch.Size([32, 1004, 16]) and torch.Size([32, 16]).
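
In case it helps, here is a minimal standalone check of that shape convention, assuming loss_fn is a plain nn.CrossEntropyLoss (class dimension second in the input, targets without a class dimension):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

batch_size, seq_len, voc_dim = 32, 16, 1004
scores = torch.randn(batch_size, seq_len, voc_dim)          # model output
targets = torch.randint(0, voc_dim, (batch_size, seq_len))  # captions[:, 1:]

# CrossEntropyLoss expects input (N, C, L) and target (N, L)
loss = loss_fn(scores.transpose(1, 2), targets)
print(loss.item())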

Any help is greatly appreciated!