TransformerDecoder masks shape error using model.eval()

I’m getting an error during validation on the masks of the TransformerDecoder.
Found this post talking about a similar error on the Encoder, and the solution was to update torch, but it didn’t fix my problem.

Currently using pytorch 2.5.1, cuda 12.1
During training masks work with no issue in both encoder and decoder.
During evaluation, however, using model.eval() and torch.no_grad() I get the following error on the Decoder only:

Mask shape should match input. mask: [16, 1, 350, 350] input: [16, 8, 350, 350]

The shapes I'm using are:
tgt.shape = [batch, 350]
memory.shape = [batch, 350, 512]
tgt_mask.shape = [16,350,350]
memory_key_padding_mask.shape = [16,350]

Which are the exact same shapes used in training, where no error is raised. Changing to model.train() makes the error go away but for validation it’s obviously not ideal.

If I try adding the extra dimension from the error, to make the tgt_mask be [b, 1, 350,350]
the error changes to:

AssertionError: For batched (3-D) query, expected attn_mask to be None, 2-D or 3-D but found 4-D tensor instead

And explicitly adding the number of heads to make the tgt_mask shape [b, 8, 350,350] also gives the same error.

To be precise, the configuration I’m using for the decoder is the following:

x = self.decoder(tgt=x,
                         memory=encOut,
                         tgt_mask=tgtMask,
                         memory_key_padding_mask=srcMask,
                         tgt_is_causal=True)

Thanks for your help

Do you see the same issue using the latest nightly release?
If so, could you post a minimal and executable code snippet reproducing the error, please?

Tested code in nightly, same error.

Before testing the nightly release, I also tested changing the masks.
In the code with errors I sent previously, the tgt_mask was created like this:

def causalMask(size):
    """Boolean causal mask of shape [1, size, size].

    True on and below the diagonal — positions a token may attend to
    (i.e. no look-ahead); False above the diagonal.
    """
    allowed = torch.tril(torch.ones(1, size, size))
    return allowed.bool()

# Padding mask: 1 where the token is not the pad token.
# NOTE(review): the two unsqueeze(0) calls give shape [1, 1, seq], which assumes
# decInputT is a single unbatched sequence — verify against the caller.
decMask = (decInputT != self.padToken).unsqueeze(0).unsqueeze(0).int()
# Combine the padding mask with the causal mask via logical AND.
# NOTE(review): causalMask is sized with decInputT.size(0); if decInputT is
# batched [batch, seq] that is the batch size, not the sequence length —
# confirm size(1) isn't intended here.
decMask = decMask & causalMask(decInputT.size(0))

Meaning it was the causal mask and padding mask together.

I thought maybe this was the source of the problem so I tried separating them, changing the decoder declaration to:

        x = self.decoder(tgt=x,
                         memory=encOut,
                         tgt_mask=tgtCauMask,
                         tgt_key_padding_mask=tgtPadMask, 
                         memory_key_padding_mask=srcMask,
                         tgt_is_causal=True)

To make this decoder run, the shapes of the mask were the following:
tgt.shape = [batch, 350]
memory.shape = [batch, 350, 512]
tgt_mask.shape = [350,350]
tgt_key_padding_mask.shape = [b,350]
memory_key_padding_mask.shape = [16,350]

This way the code didn’t raise any errors, but the validation loops all produce NaN values.

Am I doing something wrong with the masks?

I tried making a sample code to reproduce the error, but I'm getting a different error:

RuntimeError: value cannot be converted to type int64_t without overflow

Anyway, here’s the code. I checked that the sizes and data types are the same as in the original code, yet this new error appears only here.

# -*- coding: utf-8 -*-
"""
Created on Mon Jan 27 15:12:41 2025

@author: Mateo-drr
"""

import torch
import torch.nn as nn
from torch.nn.functional import pad

# Configuration for the minimal repro.
batch_size = 16   # sequences per batch
seq_len = 350     # tokens per sequence
embed_dim = 512   # model dimension (d_model)
num_heads = 8     # attention heads per layer
pad_token = 0     # token id treated as padding

def causal_mask(size):
    """Create a causal mask for the decoder.

    Returns a bool tensor of shape [1, size, size] that is True on and
    below the diagonal (attendable past positions) and False for future
    positions.
    """
    future = torch.triu(torch.ones(1, size, size), diagonal=1)
    return future.int() == 0

# Dummy data
# NOTE(review): tgt is int64 token ids of shape [batch, seq]. nn.TransformerDecoder
# expects already-embedded float input of shape [batch, seq, d_model]; feeding raw
# ids is the likely source of the "value cannot be converted to type int64_t
# without overflow" error (filling -inf into an integer tensor).
tgt = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.int64)
memory = torch.rand(batch_size, seq_len, embed_dim).to(torch.float32)

# Masks — built with broadcasting instead of the original per-sample Python loop;
# shapes and dtypes are identical to the loop version.
# NOTE(review): these masks use True/1 = "keep", but PyTorch boolean attention
# masks use True = "not allowed to attend" — confirm the intended polarity.
tgt_causal_mask = causal_mask(seq_len)  # [1, seq, seq], True = attendable
# [batch, 1, 1, seq]: 1 where the token is not padding.
tgt_padding_mask = (tgt != pad_token).int().unsqueeze(1).unsqueeze(1)
# [batch, 1, seq, seq]: padding AND causal, broadcast over the batch (int dtype,
# matching the original int & bool promotion).
tgt_combined_mask = tgt_padding_mask & tgt_causal_mask.unsqueeze(0)
# One causal mask per batch element: [batch, seq, seq].
tgt_causal_mask = tgt_causal_mask.repeat(batch_size, 1, 1)


def make_decoder(embed_dim, num_heads):
    """Build a 6-layer, batch-first nn.TransformerDecoder."""
    layer = nn.TransformerDecoderLayer(
        d_model=embed_dim,
        nhead=num_heads,
        batch_first=True,
    )
    stack = nn.TransformerDecoder(layer, num_layers=6)
    return stack

# Model
decoder = make_decoder(embed_dim, num_heads)

# Convert masks to bool, since boolean masks take a different code path than
# int/float masks in nn.MultiheadAttention.
# NOTE(review): for PyTorch boolean masks, True means the position is IGNORED;
# these masks use True = "keep" — confirm the polarity before trusting outputs.
tgt_padding_mask = tgt_padding_mask.bool()
mem_padding_mask = tgt_padding_mask.clone()
tgt_causal_mask = tgt_causal_mask.bool()
tgt_combined_mask = tgt_combined_mask.bool()

# Drop the broadcast head dimension from the combined mask.
tgt_combined_mask = tgt_combined_mask.squeeze(1) #[b,1,350,350] -> [b,350,350]

'''
Test 1 - separate masks
'''

# Shape/dtype dump for the repro report.
print('tgt', tgt.shape, tgt.type())
print('mem',memory.shape, memory.type())
print('pad',tgt_padding_mask.squeeze(1).squeeze(1).shape, tgt_padding_mask.type())
print('cau',tgt_causal_mask[0].shape, tgt_causal_mask.type())
print('com',tgt_combined_mask.shape, tgt_combined_mask.type())
print('mempad',tgt_padding_mask.squeeze(1).squeeze(1).shape, tgt_padding_mask.type())

# NOTE(review): tgt is int64 token ids, but nn.TransformerDecoder expects float
# embeddings of shape [batch, seq, d_model]; passing raw ids is the likely source
# of the "value cannot be converted to type int64_t without overflow" error.
decoder.train()
output = decoder(
    tgt=tgt,  
    memory=memory,  
    tgt_mask=tgt_causal_mask[0],  
    tgt_key_padding_mask = tgt_padding_mask.squeeze(1).squeeze(1),
    memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),  
    tgt_is_causal=True
)

# Same call under eval + no_grad; the inference fast path validates masks more
# strictly than training mode does.
# NOTE(review): for boolean key_padding_mask, True = "ignore this position", but
# here True marks the real (non-pad) tokens — the inverted polarity can leave
# rows with nothing to attend to, which would explain the NaN validation outputs.
decoder.eval()
with torch.no_grad():
    output = decoder(
        tgt=tgt,  
        memory=memory,  
        tgt_mask=tgt_causal_mask[0],  
        tgt_key_padding_mask = tgt_padding_mask.squeeze(1).squeeze(1),
        memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),  
        tgt_is_causal=True
    )


'''
Test 2 - joint masks
'''

# Shape/dtype dump for the repro report.
print('\ntgt', tgt.shape, tgt.type())
print('mem',memory.shape, memory.type())
print('pad',tgt_padding_mask.shape, tgt_padding_mask.type())
print('cau',tgt_causal_mask.shape, tgt_causal_mask.type())
print('com',tgt_combined_mask.shape, tgt_combined_mask.type())
print('mempad',tgt_padding_mask.squeeze(1).squeeze(1).shape, tgt_padding_mask.type())

# NOTE(review): tgt_mask here is 3-D [batch, L, S]. Training accepts it, but the
# eval fast path unsqueezes it to [batch, 1, L, S] and compares against the
# per-head shape [batch, heads, L, S] — producing exactly the reported
# "mask: [16, 1, 350, 350] input: [16, 8, 350, 350]" mismatch. A 2-D [L, S]
# mask, or a [batch*heads, L, S] mask, satisfies the check.
# NOTE(review): tgt_is_causal=True asserts that tgt_mask is the plain causal
# mask; passing a combined padding+causal mask contradicts that hint.
decoder.train()
output = decoder(
    tgt=tgt,  
    memory=memory,  
    tgt_mask=tgt_combined_mask,  
    # tgt_key_padding_mask = tgt_padding_mask,
    memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),  
    tgt_is_causal=True
)

# Same call under eval + no_grad — this is where the shape error is raised.
decoder.eval()
with torch.no_grad():
    output = decoder(
        tgt=tgt,  
        memory=memory,  
        tgt_mask=tgt_combined_mask,  
        # tgt_key_padding_mask = tgt_padding_mask,
        memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),  
        tgt_is_causal=True
    )


Thank you for your help