Encoder-decoder transformer: poor performance during inference

I am working on a full encoder-decoder transformer model to synthesize speech from EEG signals. Specifically, for a window of EEG activity of length 100, I predict a window of mel spectrograms of the same length. The EEG and mel spectrograms are aligned in time, with total dataset dimensions (43265, 107) for EEG and (43264, 80) for mel spectrograms. I divided the dataset into training and testing sets with an 80/20 split, which results in 6902 training sequences, each with dimensions (100, 107) for EEG and (100, 80) for mel spectrograms.

My model architecture includes:

  • Two prenets (one for the encoder and one for the decoder) to extract features from the EEG and mel spectrograms, projecting them into embeddings.
  • A postnet to refine the predicted mel spectrograms (see the sketch after this list).
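
For context, these blocks follow the Transformer-TTS design; below is a simplified sketch of what I mean by prenet and postnet (not my exact modules; dimensions are assumed):

import torch.nn as nn

class Prenet(nn.Module):
    """Projects raw frames (EEG or mel) into d_model-dim embeddings."""
    def __init__(self, in_dim, d_model, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, d_model), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Dropout(dropout),
        )

    def forward(self, x):          # (batch, seq_len, in_dim)
        return self.net(x)         # (batch, seq_len, d_model)

class Postnet(nn.Module):
    """Residual conv stack that refines the predicted mel spectrogram."""
    def __init__(self, n_mels=80, channels=256, kernel_size=5, n_layers=5):
        super().__init__()
        pad = kernel_size // 2
        layers = [nn.Conv1d(n_mels, channels, kernel_size, padding=pad), nn.Tanh()]
        for _ in range(n_layers - 2):
            layers += [nn.Conv1d(channels, channels, kernel_size, padding=pad), nn.Tanh()]
        layers += [nn.Conv1d(channels, n_mels, kernel_size, padding=pad)]
        self.net = nn.Sequential(*layers)

    def forward(self, mel):                    # (batch, seq_len, n_mels)
        out = self.net(mel.transpose(1, 2))    # convolve over the time axis
        return mel + out.transpose(1, 2)       # residual refinement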

The issue I’m facing is that while the training loss decreases, the model performs poorly during inference: the predictions on the validation set are very poor, and the model also underperforms on the training set itself. During inference I predict the data the following way:

eeg_val = eeg_val.to(device)
mel_val = mel_val.to(device)
mel_input = torch.zeros([modelArgs.batch_size, 1, 80]).to(device)
pos_eeg = torch.arange(1, eeg_context_length + 1).repeat(modelArgs.batch_size, 1).to(device)
pbar = tqdm(range(config["TR"]["context_length"]), desc="Validating...", position=0, leave=False)
with torch.no_grad():
    for _ in pbar:
        pos_mel = torch.arange(1, mel_input.size(1) + 1).repeat(modelArgs.batch_size, 1).to(device)
        mel_out, postnet_pred, attn, _, attn_dec = model.forward(eeg_val, mel_input, pos_eeg, pos_mel)
        # Append the newest predicted frame to the growing decoder input
        mel_input = torch.cat([mel_input, mel_out[:, -1:, :]], dim=1)
batch_loss = criterion(postnet_pred, mel_val)

where:

  • config["TR"]["context_length"] is the length of the window i.e. 100
  • pos_eeg and pos_mel are used to create masks for attention (see the sketch after this list)
  • mel_out is the output of the decoder, postnet_pred is the output of the postnet
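
For reference, a minimal sketch of how such position tensors can be turned into attention masks, assuming a position index of 0 marks padding (the build_masks helper below is illustrative, not my exact code):

import torch

def build_masks(pos_src, pos_tgt):
    # Key-padding masks: True where the position index is 0 (padding)
    src_key_padding_mask = pos_src.eq(0)  # (batch, src_len)
    tgt_key_padding_mask = pos_tgt.eq(0)  # (batch, tgt_len)

    # Triangular causal mask so each mel frame only attends to earlier frames
    tgt_len = pos_tgt.size(1)
    causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
    return src_key_padding_mask, tgt_key_padding_mask, causal_mask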

Training overview
[training loss plot]

The loss is calculated with nn.L1Loss() on the output of the decoder and the output of the postnet: batch_loss = mel_loss + post_mel_loss
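
In code, that is roughly (variable names as above, mel being the ground-truth spectrogram):

criterion = nn.L1Loss()
mel_loss = criterion(mel_out, mel)            # decoder output vs. ground truth
post_mel_loss = criterion(postnet_pred, mel)  # postnet output vs. ground truth
batch_loss = mel_loss + post_mel_loss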

My model is based on the Neural Speech Synthesis with Transformer Network (2018) paper, and I use the following implementation. The only differences between my setup and the text-to-speech one are:

  • I use EEG instead of text
  • I don’t have a stop token as I predict for a fixed time window.
  • I create positional encodings using the following class instead of the nn.Embedding module:
import numpy as np
import torch as t
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Learnable scale for the positional encoding, as in Transformer-TTS
        self.alpha = nn.Parameter(t.ones(1))
        # Precompute the sinusoidal table once: (max_len, d_model)
        pe = t.zeros(max_len, d_model)
        position = t.arange(0, max_len, dtype=t.float).unsqueeze(1)
        div_term = t.exp(t.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = t.sin(position * div_term)
        pe[:, 1::2] = t.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        pos = self.pe[:x.shape[1]]            # (seq_len, d_model)
        pos = t.stack([pos] * x.shape[0], 0)  # (batch, seq_len, d_model)
        x = pos * self.alpha + x
        return self.dropout(x)
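
It is applied to the prenet embeddings like so (d_model here is arbitrary):

pos_enc = PositionalEncoding(d_model=256, dropout=0.1, max_len=200)
x = t.randn(8, 100, 256)  # (batch, seq_len, d_model) prenet embeddings
x = pos_enc(x)            # alpha-scaled sinusoidal encoding added, then dropout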

I also use the NoamOpt learning-rate scheduler from this tutorial: https://nlp.seas.harvard.edu/2018/04/03/attention.html
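
For reference, the schedule from that tutorial computes the learning rate as follows (factor and warmup shown with the tutorial's defaults, not necessarily the values I use):

def noam_rate(step, d_model=512, factor=2, warmup=4000):
    # lrate = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
    # step must be >= 1 to avoid division by zero
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)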

Question: What surprises me the most is that, despite a drastically decreasing loss and very good correlation scores (between prediction and ground truth) during training, the network performs very poorly on the same training sequences during inference. What could be the reason for the model’s poor autoregressive performance?

Could you describe the input data processing during training vs. inference, as I guess your training data might give the model some “hints” towards the target while the inference use case might not?

@ptrblck During training I create the triangular mask for the targets and shift the targets by one, as described in the original transformer paper.

start_token = torch.zeros((1, 1, 80)).to(device)
for batch_idx, (eeg, mel) in enumerate(train_loader):
    eeg = eeg.to(device)
    mel = mel.to(device)
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(mel.shape[1]).to(device)
    mel_shf = create_decoder_input(mel, start_token)  # targets shifted right by one
    mel_pred = model(eeg, mel_shf, tgt_mask, tgt_is_causal=True)
    batch_loss = criterion(mel_pred, mel)
    ...

Where create_decoder_input simply shifts the mel spectrograms by one:

def create_decoder_input(target_seq, start_token):
    batch_size = target_seq.size(0)
    start_tokens = start_token.repeat(batch_size, 1, 1).to(target_seq.device)  # Shape: (batch_size, 1, mel_spec_frame_dim)
    shifted_target = torch.cat([start_tokens, target_seq[:, :-1, :]], dim=1)  # Shifted right by one position
    return shifted_target
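
As a toy example with one-dimensional "frames", the target [1, 2, 3] becomes the decoder input [0, 1, 2], so when predicting frame t the decoder only sees frames 0..t-1:

tgt = torch.tensor([[[1.], [2.], [3.]]])  # (batch=1, seq_len=3, dim=1)
start = torch.zeros((1, 1, 1))
create_decoder_input(tgt, start)          # frames become [0., 1., 2.]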

I use PyTorch’s TransformerEncoder and TransformerDecoder classes in the model.
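
For context, the encoder/decoder stacks are assembled roughly like this (the hyperparameters here are placeholders, not my actual ones):

import torch
import torch.nn as nn

d_model, n_heads, n_layers = 256, 4, 3  # assumed for illustration

enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
dec_layer = nn.TransformerDecoderLayer(d_model, n_heads, batch_first=True)
encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
decoder = nn.TransformerDecoder(dec_layer, num_layers=n_layers)

src = torch.randn(8, 100, d_model)   # prenet output for the EEG window
tgt = torch.randn(8, 100, d_model)   # prenet output for the shifted mel frames
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))

memory = encoder(src)                          # (8, 100, d_model)
out = decoder(tgt, memory, tgt_mask=tgt_mask)  # (8, 100, d_model)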

During inference I first encode the source, and then, for the length of the output sequence, I predict each frame of the mel spectrogram, appending the output of the decoder to its input.

def inference(eeg_data, model, max_length=200):
    model.eval()
    with torch.no_grad():
        # Step 1: Encode the EEG data once; the memory is reused at every step
        encoded_eeg = model.encode(eeg_data)

        # Step 2: Initialize the decoder with a zero start frame
        decoder_input = torch.zeros((eeg_data.size(0), 1, 80)).to(eeg_data.device)
        pbar = tqdm(range(max_length), desc="Validating...", position=0, leave=False)
        for _ in pbar:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                decoder_input.shape[1]
            ).to(eeg_data.device)

            # Step 3: Autoregressive decoding
            mel = model.decode(decoder_input, encoded_eeg, tgt_mask=tgt_mask, tgt_is_causal=True)

            # Select the last frame in the sequence (most recent prediction)
            predicted_frame = mel[:, -1:, :]  # (batch_size, 1, mel_spec_frame_dim)

            # Append the new frame to the decoder input sequence
            decoder_input = torch.cat([decoder_input, predicted_frame], dim=1)

    return mel