I tried writing a small sample script to reproduce the error, but I'm getting a different error:
RuntimeError: value cannot be converted to type int64_t without overflow
Anyway, here's the code. I checked that the sizes and data types are the same as in the original code, yet this new error only appears here.
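For what it's worth, I can reproduce the bare error message in isolation by asking PyTorch to write float('-inf') into an integer tensor; my guess (and it is only a guess) is that something similar happens somewhere inside the decoder, but I don't know where:

import torch

scores = torch.zeros(2, 2, dtype=torch.int64)
mask = torch.tensor([[False, True],
                     [False, False]])
# Raises: RuntimeError: value cannot be converted to type int64_t without overflow
scores = scores.masked_fill(mask, float('-inf'))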
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 27 15:12:41 2025
@author: Mateo-drr
"""
import torch
import torch.nn as nn
from torch.nn.functional import pad
# Configuration
batch_size = 16
seq_len = 350
embed_dim = 512
num_heads = 8
pad_token = 0
def causal_mask(size):
    """Create a causal mask for the decoder."""
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
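# For reference, causal_mask(3) returns a [1, 3, 3] boolean tensor that is
# True on and below the diagonal:
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])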
# Dummy data
tgt = torch.randint(0, 100, (batch_size, seq_len), dtype=torch.int64)  # [b, seq_len], int64 token ids
memory = torch.rand(batch_size, seq_len, embed_dim).to(torch.float32)  # [b, seq_len, embed_dim], float32
# Masks
tgt_padding_mask = []
tgt_combined_mask = []
tgt_causal_mask = causal_mask(seq_len)  # [1, seq_len, seq_len]
for i in range(batch_size):
    padmask = (tgt[i] != pad_token).unsqueeze(0).unsqueeze(0).int()  # [1, 1, seq_len]
    # Combine causal and padding masks
    combmask = padmask & tgt_causal_mask  # [1, seq_len, seq_len]
    tgt_padding_mask.append(padmask)
    tgt_combined_mask.append(combmask)
tgt_padding_mask = torch.stack(tgt_padding_mask)    # [b, 1, 1, seq_len]
tgt_combined_mask = torch.stack(tgt_combined_mask)  # [b, 1, seq_len, seq_len]
tgt_causal_mask = tgt_causal_mask.repeat(batch_size, 1, 1)  # [b, seq_len, seq_len]
def make_decoder(embed_dim, num_heads):
    """Simple TransformerDecoder initialization."""
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=embed_dim,
        nhead=num_heads,
        batch_first=True,
    )
    return nn.TransformerDecoder(decoder_layer, num_layers=6)
# Model
decoder = make_decoder(embed_dim, num_heads)
# Make masks boolean
tgt_padding_mask = tgt_padding_mask.bool()
mem_padding_mask = tgt_padding_mask.clone()
tgt_causal_mask = tgt_causal_mask.bool()
tgt_combined_mask = tgt_combined_mask.bool()

# Remove the extra dimension
tgt_combined_mask = tgt_combined_mask.squeeze(1)  # [b, 1, 350, 350] -> [b, 350, 350]
'''
Test 1 - separate masks
'''
print('tgt', tgt.shape, tgt.type())
print('mem', memory.shape, memory.type())
print('pad', tgt_padding_mask.squeeze(1).squeeze(1).shape, tgt_padding_mask.type())
print('cau', tgt_causal_mask[0].shape, tgt_causal_mask.type())
print('com', tgt_combined_mask.shape, tgt_combined_mask.type())
print('mempad', mem_padding_mask.squeeze(1).squeeze(1).shape, mem_padding_mask.type())
decoder.train()
output = decoder(
    tgt=tgt,
    memory=memory,
    tgt_mask=tgt_causal_mask[0],
    tgt_key_padding_mask=tgt_padding_mask.squeeze(1).squeeze(1),
    memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),
    tgt_is_causal=True,
)
decoder.eval()
with torch.no_grad():
    output = decoder(
        tgt=tgt,
        memory=memory,
        tgt_mask=tgt_causal_mask[0],
        tgt_key_padding_mask=tgt_padding_mask.squeeze(1).squeeze(1),
        memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),
        tgt_is_causal=True,
    )
'''
Test 2 - joint masks
'''
print('\ntgt', tgt.shape, tgt.type())
print('mem', memory.shape, memory.type())
print('pad', tgt_padding_mask.shape, tgt_padding_mask.type())
print('cau', tgt_causal_mask.shape, tgt_causal_mask.type())
print('com', tgt_combined_mask.shape, tgt_combined_mask.type())
print('mempad', mem_padding_mask.squeeze(1).squeeze(1).shape, mem_padding_mask.type())
decoder.train()
output = decoder(
    tgt=tgt,
    memory=memory,
    tgt_mask=tgt_combined_mask,
    # tgt_key_padding_mask=tgt_padding_mask,
    memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),
    tgt_is_causal=True,
)
decoder.eval()
with torch.no_grad():
    output = decoder(
        tgt=tgt,
        memory=memory,
        tgt_mask=tgt_combined_mask,
        # tgt_key_padding_mask=tgt_padding_mask,
        memory_key_padding_mask=mem_padding_mask.squeeze(1).squeeze(1),
        tgt_is_causal=True,
    )
Thank you for your help!