I am trying to build a CNN+Transformer architecture, but the transformer seems to be doing far too well: it takes only 3 epochs to drive the losses to zero. My assumption is that it is somehow able to see the entire target sequence at once, but I cannot figure out how, since I am providing `tgt_mask`.
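For what it's worth, the mask helper in the model (full code below) does seem to produce the standard causal pattern, so I don't think the mask values themselves are wrong. Here is a standalone sanity check of the same logic for `sz = 4`:

```python
import torch

# Same logic as _generate_square_subsequent_mask below, checked for sz = 4:
sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```

So position i can only attend to positions <= i, as expected. The full model: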
```python
import math

import torch
import torch.nn as nn

from models.encoding.positional_encoding_for_xfmer import PositionalEncoding


class Transformer(nn.Module):
    def __init__(self, output_dim, dec_hid_dim, max_len, nheads,
                 n_encoder_layers, n_decoder_layers, dropout, device):
        super(Transformer, self).__init__()
        self.pos = PositionalEncoding(dec_hid_dim, dropout, max_len)
        self.device = device
        self.output_dim = output_dim
        self.dec_hid_dim = dec_hid_dim
        self.embed = nn.Embedding(output_dim, dec_hid_dim)
        self.xfmer = nn.Transformer(
            d_model=dec_hid_dim,
            nhead=nheads,
            num_encoder_layers=n_encoder_layers,
            num_decoder_layers=n_decoder_layers,
            dropout=dropout,
        )
        # Maps the CNN feature length (330) to max_len so src and trg lengths align.
        self.modify_len = nn.Linear(330, max_len)
        self.out = nn.Linear(dec_hid_dim, output_dim)

    def _generate_square_subsequent_mask(self, sz):
        # Standard causal mask: 0.0 on and below the diagonal, -inf above it.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, trg, sos_idx, pad_idx, is_test):
        # src: (B, L=330, dec_hid_dim); sos_idx, pad_idx, is_test are currently unused here.
        src = self.modify_len(src.permute(0, 2, 1)).permute(2, 0, 1)  # (max_len, B, dec_hid_dim)
        src = src * math.sqrt(self.dec_hid_dim)
        trg = trg.permute(1, 0)                                       # (max_len, B)
        trg = self.embed(trg) * math.sqrt(self.dec_hid_dim)           # (max_len, B, dec_hid_dim)
        src = self.pos(src)
        trg = self.pos(trg)
        trg_mask = self._generate_square_subsequent_mask(trg.shape[0]).to(self.device)
        xfmer_dec_outputs = self.xfmer(src, trg, tgt_mask=trg_mask)
        xfmer_dec_outputs = self.out(xfmer_dec_outputs)
        return xfmer_dec_outputs.permute(1, 0, 2)                     # (B, max_len, output_dim)
```
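For completeness, my understanding of the teacher-forcing convention for `nn.Transformer` is that the decoder input should be the target shifted right (drop the last token) and the loss should be computed against the target shifted left (drop the first token), so that position i has to predict token i+1 rather than simply reproduce the token it was fed. This is only a sketch of that convention; `model`, `criterion`, `src`, and `trg` here are placeholders, not my actual training loop:

```python
# Hypothetical training step (placeholder names): the decoder sees
# trg[:, :-1] and the loss is computed against trg[:, 1:].
output = model(src, trg[:, :-1], sos_idx, pad_idx, is_test=False)  # (B, L-1, output_dim)
loss = criterion(
    output.reshape(-1, output.size(-1)),  # (B * (L-1), output_dim)
    trg[:, 1:].reshape(-1),               # (B * (L-1),)
)
```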
Loss plot: [image: the loss drops to zero within 3 epochs]
Any suggestions would be helpful. Thank you in advance!