Transformer gives zero train and val loss in just 3-4 epochs!

I am building a CNN+Transformer architecture, but the transformer appears to be doing far too well: it drives both the training and validation losses to zero in only 3 epochs. My suspicion is that the decoder is somehow able to see the entire target sequence at once, but I cannot figure out how, since I am providing a causal "tgt_mask".
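
As a sanity check, the mask itself looks correct to me. Printing a standalone construction equivalent to the _generate_square_subsequent_mask method below gives the expected causal pattern:

import torch

sz = 4  # illustrative target length
mask = torch.zeros(sz, sz).masked_fill(
    torch.triu(torch.ones(sz, sz), diagonal=1).bool(), float('-inf'))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

So position i should only be able to attend to positions <= i, which is why I don't see how the decoder could be looking ahead.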

import math
import torch
import torch.nn as nn
from models.encoding.positional_encoding_for_xfmer import PositionalEncoding


class Transformer(nn.Module):

    def __init__(self, output_dim, dec_hid_dim, max_len, nheads, n_encoder_layers,
                    n_decoder_layers, dropout, device):

        super(Transformer, self).__init__()
        self.pos = PositionalEncoding(dec_hid_dim, dropout, max_len)
        self.device = device
        self.output_dim = output_dim
        self.dec_hid_dim = dec_hid_dim
        self.embed = nn.Embedding(output_dim, dec_hid_dim)
        self.xfmer = nn.Transformer(
                                d_model=dec_hid_dim,
                                nhead=nheads,
                                num_encoder_layers=n_encoder_layers,
                                num_decoder_layers=n_decoder_layers,
                                dropout=dropout,
                                )
        # project the fixed CNN feature length (330) onto max_len positions
        self.modify_len = nn.Linear(330, max_len)
        self.out = nn.Linear(dec_hid_dim, output_dim)

    def _generate_square_subsequent_mask(self, sz):
        # causal mask: position i may attend only to positions <= i
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, trg, sos_idx, pad_idx, is_test):
        # sos_idx, pad_idx, and is_test are currently unused here
        # src: (B, 330, dec_hid_dim) CNN features; trg: (B, max_len) token indices
        src = self.modify_len(src.permute(0,2,1)).permute(2,0,1)  # (max_len, B, dec_hid_dim)
        src = src * math.sqrt(self.dec_hid_dim)
        trg = trg.permute(1,0)  # (max_len, B)
        trg = self.embed(trg) * math.sqrt(self.dec_hid_dim)
        src = self.pos(src) # (max_len, B, dec_hid_dim)
        trg = self.pos(trg) # (max_len, B, dec_hid_dim)
        trg_mask = self._generate_square_subsequent_mask(trg.shape[0]).to(self.device)

        xfmer_dec_outputs = self.xfmer(src, trg, tgt_mask=trg_mask)
        xfmer_dec_outputs = self.out(xfmer_dec_outputs)

        return xfmer_dec_outputs.permute(1,0,2)  # (B, max_len, output_dim)
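
For context, here is a minimal dummy forward pass showing the shapes I feed in (the sizes are illustrative, not my real hyperparameters, and it assumes the local PositionalEncoding import resolves):

model = Transformer(output_dim=500, dec_hid_dim=256, max_len=100,
                    nheads=8, n_encoder_layers=3, n_decoder_layers=3,
                    dropout=0.1, device="cpu")
src = torch.randn(4, 330, 256)          # (B, L=330, dec_hid_dim) CNN features
trg = torch.randint(0, 500, (4, 100))   # (B, max_len) token indices
out = model(src, trg, sos_idx=1, pad_idx=0, is_test=False)
print(out.shape)                        # torch.Size([4, 100, 500])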

Loss plot: [image: both training and validation loss drop to zero within 3-4 epochs]

Any suggestions would be helpful. Thank you in advance!