Converge problem of a Region Caption Model

Hello,

I’m working on a region captioning task. I’m not sure what mistakes I might have done, as I see that the loss does not decrease forever. Could somebody give me a short review of my model. Thanks in advance.

import math

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from .config import cfg
from os.path import join as opj
import json

# Language Criterion
class TemporalCrossEntropyCriterion(nn.Module):
    def __init__(self):
        super(TemporalCrossEntropyCriterion, self).__init__()
        self.lsm = nn.LogSoftmax(dim=1)
        
        # Whether to average over space and batch
        self.batch_average = True
        self.time_average = False
    
    def forward(self, input, target):
        # print("input.shape", input.shape)
        N, T, C = input.shape
        assert (target.dim() == 2 and target.size(0) == N and target.size(1) == T)
        null_mask = torch.eq(target, 0)
        target[null_mask] = 1

        ############time step case############
        # null_mask = torch.eq(target, 0)
        # target = target[:, :input.size(1)]
        # null_mask = null_mask[:, :input.size(1)]
        ############time step case############
        
        logprobs = self.lsm.forward(input)
        #losses = torch.gather(logprobs, 2, target.unsqueeze(-1)).mul(-1).squeeze(-1)
        losses = torch.gather(logprobs, 2, target.clone().view(N, T, -1)).mul(-1).squeeze(-1)
        losses[null_mask] = 0
        if self.batch_average:
            losses = torch.sum(losses) / N  # torch.sum(null_mask)
        if self.time_average:
            losses = torch.sum(losses) / T
        target[null_mask] = 0
        return losses

# Caption Model
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, device="cuda"):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        D, H = input_dim, hidden_dim
        
        self.weight = nn.Parameter(torch.randn(D + H, 4 * H))
        self.bias = nn.Parameter(torch.randn(4 * H))
        
        self.hidden = torch.empty(0).to(device)
        self.cell = torch.empty(0).to(device)
        self.gates = torch.empty(0).to(device)
        
        self.h0 = torch.empty(0).to(device)
        self.c0 = torch.empty(0).to(device)
        
        self.remember_states = False
        self.device = device
    
    def reset(self, std):
        if not std:
            std = 1.0 / math.sqrt(self.hidden_dim + self.input_dim)
        
        self.bias.zero_()
        self.bias[self.hidden_dim, 2 * self.hidden_dim].fill_(1)
        self.weight.normal_(0, std)
    
    def reset_states(self):
        self.h0 = torch.empty(0).to(self.device)
        self.c0 = torch.empty(0).to(self.device)
    
    def check_dims(self, x, dims):
        assert (x.dim() == len(dims))
        for i, d in enumerate(dims):
            assert x.size(i) == d
    
    def unpack_input(self, input):
        c0, h0, x = None, None, None
        if type(input) == tuple and len(tuple) == 3:
            c0, h0, x = input
        elif type(input) == tuple and len(tuple) == 2:
            h0, x = input
        elif torch.is_tensor(input):
            x = input
        else:
            assert False, "invalid input"
        
        return c0, h0, x
    
    def get_sizes(self, input):
        c0, h0, x = self.unpack_input(input)
        N, T = x.size(0), x.size(1)
        H, D = self.hidden_dim, self.input_dim
        self.check_dims(x, [N, T, D])
        if h0:
            self.check_dims(h0, [N, H])
        if c0:
            self.check_dims(c0, [N, H])
        
        return N, T, D, H
    
    def lstm_step_forward(self, x, prev_h, prev_c, Wx, Wh, b):
        next_h, next_c = None, None

        H = prev_h.shape[1]
        
        gates = torch.addmm(b, x, Wx)

        gates.addmm(prev_h, Wh)
        
        # slice gate vector
        gate_i = gates[:, 0:H]
        gate_f = gates[:, H:2 * H]
        gate_o = gates[:, 2 * H:3 * H]
        gate_g = gates[:, 3 * H:]
        
        # activation functions applied to our 4 gates.
        input_gate = torch.sigmoid(gate_i)
        forget_gate = torch.sigmoid(gate_f)
        output_gate = torch.sigmoid(gate_o)
        block_input = torch.tanh(gate_g)
        
        # calculate next cell state
        next_c = (forget_gate * prev_c) + (input_gate * block_input)
        
        # calculate next hidden state
        next_h = output_gate * torch.tanh(next_c)
        
        return next_h, next_c
    
    def forward(self, input):
        c0, h0, x = self.unpack_input(input)
        N, T, D, H = self.get_sizes(input)
        
        if not c0:
            c0 = self.c0
            if c0.numel() == 0 or not self.remember_states:
                # c0.resize_(N, H).zero_().to(self.device)
                self.c0 = Variable(torch.zeros(N, H)).to(self.device)
                c0 = self.c0
            elif self.remember_states:
                # c0 = self.c0
                prev_N, prev_T = self.cell.shape
                assert prev_N == N, 'batch sizes must be constant to remember states'
                c0 = self.cell[:, :prev_T]
        
        if not h0:
            h0 = self.h0
            if h0.numel() == 0 or not self.remember_states:
                self.h0 = Variable(torch.zeros(N, H)).to(self.device)
                h0 = self.h0
            elif self.remember_states:
                prev_N, prev_T = self.hidden.shape
                assert prev_N == N, 'batch sizes must be the same to remember states'
                h0 = self.hidden[:, :prev_T]
        
        bias_expand = self.bias.view(1, 4 * H).expand(N, 4 * H).to(self.device)
        
        Wx = self.weight[:D, :]
        
        Wh = self.weight[D:D + H, :]
        
        self.hidden = Variable(torch.zeros(N, T, H).to(self.device))
        self.cell = Variable(torch.zeros(N, T, H).to(self.device))
        # h, c = self.output, self.cell
        
        prev_h, prev_c = h0, c0
        
        for t in range(T):
            prev_h, prev_c = self.lstm_step_forward(x[:, t, :], prev_h, prev_c, Wx, Wh, bias_expand)
            self.hidden[:, t, :] = prev_h
            self.cell[:, t, :] = prev_c
        
        if not self.training:
            self.hidden = self.hidden.squeeze(1)
            self.cell = self.cell.squeeze(1)
        
        return self.hidden


class ParallelTable(nn.Module):
    def __init__(self, image_input_dim=4096, image_embed_dim=512, rnn_size=512, vocab_size=512):
        super(ParallelTable, self).__init__()
        # For mapping from image vectors to word vectors
        self.image_encoder = nn.Sequential(
            nn.Linear(image_input_dim, image_embed_dim),
            nn.ReLU(True)
        )
        # For mapping word indices to word vectors
        self.lookup_table = nn.Embedding(vocab_size + 2, rnn_size)
    
    def forward(self, inputs):
        assert isinstance(inputs, tuple), "invalid format of inputs"
        feats, seqs = inputs
        encode_feats = self.image_encoder(feats)
        emd_words = self.lookup_table(seqs)
        return torch.cat([encode_feats.unsqueeze(1), emd_words], dim=1)


class RNN_LanguageModel(nn.Module):
    def __init__(self, opt=None, cfg=None):
        super(RNN_LanguageModel, self).__init__()
        
        vocab_pth = opj(cfg.DATASET.DATA_DIR, cfg.LANGUAGE_MODEL.VOCAB_FILE)
        with open(vocab_pth, "r") as fd:
            region_dicts = json.load(fd)
            cfg.LANGUAGE_MODEL.VOCAB_SIZE = len(region_dicts["vocab"])
            token_to_idx = region_dicts["token_to_idx"]
            self.idx_to_token = region_dicts["idx_to_token"]
            # self.img_id_to_dir = region_dicts["img_id_to_dir"]
            fd.close()
        
        self.vocab_size = len(region_dicts["vocab"])
        self.bos_idx = self.vocab_size + 1
        self.eos_idx = self.vocab_size + 1 
        self.eos_token = "<eos>"
        self.bos_token = "<bos>"
        self.pad_idx = self.vocab_size + 2
        vocab_size = self.pad_idx
        
        self.image_vector_dim = cfg.LANGUAGE_MODEL.FEATS_DIM  
        self.input_encoding_size = cfg.LANGUAGE_MODEL.FEATS_ENCODE_DIM
        self.rnn_type = cfg.LANGUAGE_MODEL.CELL_TYPE  
        self.rnn_size = cfg.LANGUAGE_MODEL.HIDDEN_DIM  
        self.num_layers = cfg.LANGUAGE_MODEL.NUM_LAYERS  
        self.dropout = 0.5  # opt.drop_prob_lm
        self.seq_length = cfg.LANGUAGE_MODEL.SEQUENCE_LEN
        
        # sample from the distribution instead
        self.sample_argmax = True
        
        sequence_core = LSTM(input_dim=self.input_encoding_size, hidden_dim=self.rnn_size)
        dropout = nn.Dropout(self.dropout) if self.dropout else None
        logit = nn.Linear(self.rnn_size, vocab_size + 1)

        rnn = nn.Sequential()
        rnn.add_module("sequence_core", sequence_core)
        rnn.add_module("dropout", dropout)
        rnn.add_module("logit", logit)
        
        input_embedding = ParallelTable(image_input_dim=self.image_vector_dim,
                                        image_embed_dim=self.input_encoding_size,
                                        rnn_size=self.rnn_size, vocab_size=vocab_size)
        self.net = nn.Sequential()
        self.net.add_module("input_embedding", input_embedding)
        # self.net.add_module("rnn", nn.Sequential(*rnn_block))
        self.net.add_module("rnn", rnn)
    
    def pad_sequence(self, gt_sequence):
        N, T = gt_sequence.shape
        gt_pad = gt_sequence.new_zeros(N, T + 1)
        gt_pad[:, 0] = self.bos_idx
        gt_pad[:, 1:T + 1] = gt_sequence
        mask = torch.eq(gt_pad, 0)
        gt_pad[mask] = self.pad_idx
        return gt_pad
    
    def get_gt_caps(self, gt_sequence, gt_lengths):
        N, T = gt_sequence.shape
        if gt_sequence.numel() == 0:
            print(gt_sequence)
        assert gt_sequence.numel() > 0, "panic, failed ground truth"
        target = torch.zeros(N, T + 2, dtype=gt_sequence.dtype)
        target[:, 1:T + 1] = gt_sequence
        target[torch.arange(N), gt_lengths + 1] = self.eos_idx
        return target.to(gt_sequence.device)
    
    def _reset_states(self):
        for i in range(self.num_layers):
            rnn_size = len(self.net.rnn)
            for r in range(rnn_size):
                layer = self.net.rnn[r]
                if type(layer) == LSTM:
                    layer.reset_states()
                    layer.remember_states = True
                    break
    
    def setup_optimizer(self, model):
        import torch.optim
        params = []
        BASE_LR = 0.0005
        WEIGHT_DECAY = 0.0005
        WEIGHT_DECAY_BIAS = 0.0
        BIAS_LR_FACTOR = 2
        BETAS = [0.9, 0.999]
        EPS = 1e-8
        
        for key, value in model.named_parameters():
            if not value.requires_grad:
                print("%s donot need grad" % key)
                continue
            
            lr = BASE_LR
            weight_decay = WEIGHT_DECAY
            if "bias" in key:
                lr = BASE_LR * BIAS_LR_FACTOR
                weight_decay = WEIGHT_DECAY_BIAS
            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
        optimizer = torch.optim.Adam(
            params,
            lr=BASE_LR,
            betas=BETAS,
            eps=EPS
        )
        
        for p in self.state_dict():
            print(p)

        return optimizer
    
    def forward(self, feats, seqs):
        pad_seqs = self.pad_sequence(seqs)
        ret = self.net((feats, pad_seqs))
        return ret
    
    def sample(self, fc_feats):
       # sample_method = opt.get('sample_method', 'greedy')
        
        batch_size = fc_feats.size(0)
        seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
        # seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
        self._reset_states()
        
        # First timestep: image vectors, ignore output
        image_vecs_encoded = self.net.input_embedding.image_encoder.forward(fc_feats)
        # noised_img = image_vecs_encoded + torch.randn(image_vecs_encoded.shape[0], 512).to("cuda")
        self.net.rnn.forward(image_vecs_encoded.unsqueeze(1))
        
        for t in range(self.seq_length):
            if t == 0:
                words = fc_feats.data.new(batch_size).long().zero_().fill_(self.bos_idx)
            else:
                words = seq[:, t - 1]
            
            wordvecs = self.net.input_embedding.lookup_table.forward(words)
            scores = self.net.rnn.forward(wordvecs.unsqueeze(1)).view(batch_size, -1)
            if self.sample_argmax:
                _, idx = torch.max(scores, 1)
                idx = idx.view(-1).long()
            else:
                logprobs = F.softmax(self.word_logit(scores), dim=1)
                idx = torch.multinomial(logprobs, 1)
                idx = idx.view(-1).long()
            seq[:, t] = idx
        print(seq)
        return seq


if __name__ == "__main__":
    ################################################################################

    loss_fn = TemporalCrossEntropyCriterion()
    toy_model = RNN_LanguageModel(cfg=cfg).to("cuda")
    
    optim = toy_model.setup_optimizer(lm)
    optim.zero_grad()
    for iter in range(1000):
        pred = toy_model.forward(feats=train_feats, seqs=train_seqs)
        
        loss = loss_fn(pred, toy_model.get_gt_caps(train_seqs, train_lens))
        loss.backward()
        if iter % 10 == 0:
              print("Iteration: {} loss: {}".format(iter, loss))
        optim.step()
        optim.zero_grad()
    ################################################################################
    
    toy_model.eval()
    pred = toy_model.sample(test_feats)

It might be helpful to directly have a look the model structure

RNN_LanguageModel(
  (net): Sequential(
    (input_embedding): ParallelTable(
      (image_encoder): Sequential(
        (0): Linear(in_features=4096, out_features=512, bias=True)
        (1): ReLU(inplace=True)
      )
      (lookup_table): Embedding(910, 512)
    )
    (rnn): Sequential(
      (sequence_core): LSTM()
      (dropout): Dropout(p=0.5, inplace=False)
      (logit): Linear(in_features=512, out_features=909, bias=True)
    )
  )
)

Here is the training log:

Iteration: 0 loss: 18.254560470581055
Iteration: 10 loss: 17.90742301940918
Iteration: 20 loss: 17.980499267578125
Iteration: 30 loss: 17.781349182128906
Iteration: 40 loss: 17.84979248046875
Iteration: 50 loss: 17.896930694580078
Iteration: 60 loss: 17.904727935791016
Iteration: 70 loss: 17.801372528076172
Iteration: 80 loss: 18.06501579284668
Iteration: 90 loss: 17.91199493408203
Iteration: 100 loss: 17.90298843383789
Iteration: 110 loss: 17.859905242919922
Iteration: 120 loss: 18.082622528076172
Iteration: 130 loss: 17.864147186279297
Iteration: 140 loss: 18.0921630859375
.....

Seems like it does not intend to converge. The Idea is based on paper DenseCap: Fully Convolutional Localization Networks for Dense Captioning