I am using a transformer for a time series forecasting task. Below are my mask code and my transformer code. I am having a few issues, one of which is that the validation MAE is 0, yet the model performs very poorly during inference. (I suspect there may be other problems that let the model score perfectly in validation, but I wanted to start here.) I figured it may be due to an improper implementation of the mask. To check this hypothesis, I also set the masks to all 0’s, and nothing changed. Any ideas as to why that would be? Any help would be greatly appreciated.
def mask(dim1: int, dim2: int):
    return torch.triu(torch.ones(dim1, dim2) * float('-inf'), diagonal=1)
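To make the mask concrete, here is a minimal sketch that prints what this helper returns for a small case. The sizes 3 and 4 are arbitrary values picked for illustration, not the ones from my model.

```python
import torch

def mask(dim1: int, dim2: int) -> torch.Tensor:
    # Float mask: 0 where attention is allowed, -inf where it is blocked.
    return torch.triu(torch.ones(dim1, dim2) * float('-inf'), diagonal=1)

# Illustrative sizes only (assumption): 3 target steps, 4 source steps.
square = mask(3, 3)  # same shape as a tgt_mask for 3 prediction steps
rect = mask(3, 4)    # same shape as a mask over 3 target x 4 source steps

print(square)
print(rect)
```

In both cases row i has zeros in columns 0 through i and -inf everywhere to the right, so in the rectangular case target step i is only allowed to attend to source steps 0 through i.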
import torch
from torch import nn

# Embedding and DEVICE are defined elsewhere in my code (not shown here).

class Model(nn.Module):
    def __init__(self,
                 d_model: int,
                 heads: int,
                 dropout: float,
                 dim_feedforward: int,
                 stack: int,
                 # [Embedding]
                 channel_in: int,
                 window_size: int,
                 pred_size: int):
        super(Model, self).__init__()

        def mask(dim1: int, dim2: int):
            return torch.triu(torch.ones(dim1, dim2) * float('-inf'), diagonal=1)

        self.embedding = Embedding(channel_in=channel_in, window_size=window_size)

        # [Encoder]
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=heads,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout, activation='gelu',
                                                   batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=stack,
                                             norm=nn.LayerNorm(d_model))

        # [Mask]
        self.tgt_mask = mask(pred_size, pred_size).to(DEVICE)
        self.src_mask = mask(pred_size, window_size).to(DEVICE)

        # [Decoder]
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=heads,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout, activation='gelu',
                                                   batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                             num_layers=stack,
                                             norm=nn.LayerNorm(d_model))
        self.out = nn.Linear(d_model, 1)

    def forward(self, x, tgt):
        x = self.embedding(x)
        tgt = self.embedding(tgt)
        memory = self.encoder(x)
        # Positional order is (tgt, memory, tgt_mask, memory_mask),
        # so self.src_mask is used as the memory mask here.
        out = self.decoder(tgt, memory, self.tgt_mask, self.src_mask)
        out = self.out(out)
        return out
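One sanity check I could run on the mask itself is sketched below on a standalone `nn.TransformerDecoder` rather than my full model. It changes only the last target step and measures how much the outputs at the earlier steps move, with and without the causal mask. All sizes and the helper name `max_diff_on_earlier_steps` are made up for illustration. This only tests whether the mask blocks attention to later steps; it says nothing about how `tgt` is built or what values it contains.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not the values from my model).
d_model, heads, pred_size, window_size = 16, 4, 5, 8

layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=heads,
                                   dim_feedforward=32, dropout=0.0,
                                   batch_first=True, norm_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=1).eval()

# Same construction as mask(pred_size, pred_size) in the post.
causal = torch.triu(torch.ones(pred_size, pred_size) * float('-inf'), diagonal=1)

memory = torch.randn(1, window_size, d_model)
tgt = torch.randn(1, pred_size, d_model)

# Replace only the LAST target step with different random values.
tgt_changed = tgt.clone()
tgt_changed[:, -1, :] = torch.randn(d_model)

def max_diff_on_earlier_steps(tgt_mask):
    # Largest change in the outputs for all steps except the last one.
    a = decoder(tgt, memory, tgt_mask=tgt_mask).detach()
    b = decoder(tgt_changed, memory, tgt_mask=tgt_mask).detach()
    return (a[:, :-1] - b[:, :-1]).abs().max().item()

with_mask = max_diff_on_earlier_steps(causal)
without_mask = max_diff_on_earlier_steps(None)
print("with causal mask:   ", with_mask)
print("without causal mask:", without_mask)
```

If the mask is applied as intended, the first number should be essentially zero and the second clearly larger. If both numbers look the same in the real model, the mask is probably not reaching the attention layers.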
P.S. I apologize in advance for the weird formatting. I am new to the PyTorch forums and don’t know how to make the code bits greyed out. If anyone knows how to do that, it would help in future questions.