I'm trying to train a Transformer decoder to generate textual captions from a CLIP embedding autoregressively. I used this official PyTorch tutorial as the base for my project. The loss decreases during training, but at inference the model always predicts the same token for every position in the sequence. I debugged further and it does the same thing during training, so it seems that either the positional encoding or the tgt_mask is not working as expected.
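For reference, this is roughly the check I run on the training outputs that shows the problem (a minimal sketch; batch["output_caption"] comes from the decoder's forward pass shown below):
import torch

def debug_predictions(batch):
    # batch["output_caption"] is produced by TextDecoder_TRANSFORMER.forward below,
    # shape (seq_len, batch_size, vocab_size) since everything is sequence-first.
    with torch.no_grad():
        preds = batch["output_caption"].argmax(dim=-1)  # (seq_len, batch_size)
        # Every entry of preds[:, 0] comes out as the same token id, i.e. the model
        # predicts a single token for every position of the first caption.
        print(preds[:, 0])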
My Transformer Decoder:
import numpy as np
import torch
import torch.nn as nn

class TextDecoder_TRANSFORMER(nn.Module):
    def __init__(self, latent_dim=512, ff_size=1024, num_layers=4, num_heads=4,
                 dropout=0.1, activation="gelu", **kargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                          nhead=self.num_heads,
                                                          dim_feedforward=self.ff_size,
                                                          dropout=self.dropout,
                                                          activation=activation)
        self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
                                                     num_layers=self.num_layers)
        # 49408 = size of the CLIP BPE vocabulary
        self.finallayer = nn.Linear(self.latent_dim, 49408)

    def decode(self, tgt, memory, tgt_mask):
        return self.seqTransDecoder(self.sequence_pos_encoder(tgt), memory, tgt_mask)

    def forward(self, batch):
        # z is the CLIP embedding
        z, mask, padding_mask = batch["z"], batch["text_mask"], batch["text_padding_mask"]
        # Text captions (already embedded with CLIP's token_embedding, see encode_text below)
        captions = batch["clip_text_embedding"]
        captions = self.sequence_pos_encoder(captions)
        output = self.seqTransDecoder(tgt=captions,
                                      tgt_mask=mask,
                                      tgt_key_padding_mask=padding_mask,
                                      memory=z)
        output = self.finallayer(output)
        batch["output_caption"] = output
        return batch
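To make the expected shapes explicit, here is a small sanity-check sketch I run with dummy tensors (the dimensions are placeholders, and treating the CLIP embedding as a one-step, sequence-first memory of shape (1, batch, 512) is my assumption; PositionalEncoding is the class shown below):
import torch

seq_len, bs = 10, 2
decoder = TextDecoder_TRANSFORMER()
dummy_batch = {
    # one-step memory holding the CLIP embedding (my assumption about its shape)
    "z": torch.randn(1, bs, 512),
    # already-embedded caption tokens, sequence-first
    "clip_text_embedding": torch.randn(seq_len, bs, 512),
    # additive causal mask: -inf strictly above the diagonal
    "text_mask": torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1),
    # no padding in this dummy example
    "text_padding_mask": torch.zeros(bs, seq_len, dtype=torch.bool),
}
out = decoder(dummy_batch)["output_caption"]
print(out.shape)  # torch.Size([10, 2, 49408])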
The Positional Encoding Used:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model), i.e. sequence-first
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is expected to be (seq_len, batch, d_model)
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
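As a quick check that the positional encoding module itself behaves, I run something like this sketch (it only verifies that the module expects (seq_len, batch, d_model) input and adds a different offset at every position; the numbers are placeholders):
import torch

pe = PositionalEncoding(d_model=512, dropout=0.1)
pe.eval()  # disable dropout so the output is deterministic

x = torch.zeros(10, 2, 512)  # (seq_len, batch, d_model)
out = pe(x)
# With an all-zero input, out is exactly the positional encoding table,
# so consecutive positions should not be equal.
print(torch.allclose(out[0], out[1]))  # False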
And the code used to encode the textual information fed to the model:
# Extracted from: https://pytorch.org/tutorials/beginner/translation_transformer.html
def generate_square_subsequent_mask(self, sz):
    mask = (torch.triu(torch.ones((sz, sz), device=self.device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(self, tokens):
    tk_seq_len = tokens.shape[0]
    tk_mask = self.generate_square_subsequent_mask(tk_seq_len)
    # 0 is the padding token id produced by clip.tokenize
    tk_padding_mask = (tokens == 0).transpose(0, 1)
    return tk_mask, tk_padding_mask

def encode_text(self, batch):
    captions = batch["clip_text"]
    captions_tensor = []
    for caption in captions:
        caption = clip.tokenize(caption).to(self.device)
        captions_tensor.append(caption)
    captions_tensor = torch.cat(captions_tensor, dim=0).to(self.device)
    captions_tensor = captions_tensor.permute(1, 0)  # (seq_len, batch)
    batch["text_tokens"] = captions_tensor

    mask, padding_mask = self.create_mask(captions_tensor[:-1, :])
    batch["text_mask"] = mask
    batch["text_padding_mask"] = padding_mask

    caption_embedding = self.clip_model.token_embedding(captions_tensor[:-1, :]).type(self.clip_model.dtype)
    caption_embedding.to(self.device)
    batch["clip_text_embedding"] = caption_embedding
    return batch
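For completeness, this is what the two masks look like on a short dummy example (the mask construction is the same as generate_square_subsequent_mask above, just inlined without self; the token ids are made up, with 0 marking padding as produced by clip.tokenize):
import torch

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

# Dummy token ids shaped (seq_len, batch) = (4, 2); 0 is the padding id
tokens = torch.tensor([[11, 22],
                       [33, 44],
                       [55,  0],
                       [ 0,  0]])
print((tokens == 0).transpose(0, 1))  # (batch, seq_len) bool mask, True at padded positions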
# function to generate output sequence using greedy algorithm
def greedy_decode(self, text, memory, max_len):
    # The tokenizer is the SimpleTokenizer from the CLIP repository
    START_IDX = self.tokenizer.encoder['<|startoftext|>']
    EOS_IDX = self.tokenizer.encoder['<|endoftext|>']

    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(self.device)
    for i in range(max_len - 1):
        memory = memory.to(self.device)
        tgt_mask = (self.generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(self.device)
        token_embs = self.clip_model.token_embedding(ys)
        out = self.textDecoder.decode(token_embs, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = self.textDecoder.finallayer(out[0, :, :])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word[-1].item()

        ys = torch.cat([ys, torch.ones(1, 1).type(torch.long).fill_(next_word).to(self.device)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys
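And this is roughly how I call greedy_decode at inference time (a sketch: the helper name caption_from_embedding and the memory reshaping are my own, max_len=77 is CLIP's context length, and SimpleTokenizer.decode is the detokenizer from the CLIP repository):
def caption_from_embedding(self, z):
    # z: (1, 512) CLIP embedding for a single sample; unsqueeze it into a
    # one-step, sequence-first memory of shape (1, 1, 512) (my assumption).
    memory = z.unsqueeze(0)
    # the `text` argument is not used inside greedy_decode above, so I pass None
    ys = self.greedy_decode(None, memory, max_len=77)
    # map the generated token ids back to a caption string
    return self.tokenizer.decode(ys.flatten().tolist())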
I've spent days trying to figure out what is wrong, but I haven't been able to find it. I would really appreciate your help and wisdom.