I am using a CNN-transformer hybrid architecture to detect handwritten equations and convert them to LaTeX strings. All target sequences (the actual LaTeX representation of a handwritten equation) are padded/truncated to maxSeqLen (maximum sequence length).
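For concreteness, a single padded target looks roughly like this (all token ids here, including the special tokens, are purely illustrative):

    # e.g. maxSeqLen = 10, hypothetical ids for the LaTeX tokens of "x ^ { 2 }"
    # tokens: [SOS]   x    ^    {    2    }  [EOS] [PAD] [PAD] [PAD]
    # ids:   [   1,  34,  57,  60,  12,  61,    2,    0,    0,    0 ]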
This is how I generate the target sequence (causal) mask and the target padding mask:
def generate_tgt_mask(maxSeqLen: int) -> torch.Tensor:
    # Standard causal mask: 0.0 on and below the diagonal, -inf above it,
    # so position i can only attend to positions <= i.
    mask = (torch.triu(torch.ones((maxSeqLen, maxSeqLen))) == 1)
    mask = mask.transpose(0, 1).float()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))
    return mask
def generate_padding_mask(tgt, padTokenIdx=0) -> torch.Tensor:
    # Boolean mask: True wherever tgt holds the padding token.
    return (tgt == padTokenIdx)
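For reference, this is what the two functions return on a tiny example (the sample ids are made up; generate_tgt_mask matches what torch.nn.Transformer.generate_square_subsequent_mask(4) produces):

    generate_tgt_mask(4)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])

    generate_padding_mask(torch.tensor([[5, 8, 2, 0, 0]]), padTokenIdx=0)
    # tensor([[False, False, False,  True,  True]])   # shape (batch, maxSeqLen)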
Training loop snippet:
tgt_mask = generate_tgt_mask(maxSeqLen).cuda()   # created once, reused for every batch
for idx, (src, tgt, trainEq, trainFname) in enumerate(trainDataLoader):
    currTrainBatchAcc = 0
    currTrainBatchLoss = 0
    trainSrc, trainTgt = src.cuda(), tgt.cuda()
    tgt_padding_mask = generate_padding_mask(trainTgt, tokenizer.vocab['[PAD]']).cuda()
    model.train()
    trainPred: torch.Tensor = model(trainSrc, trainTgt, tgt_mask, tgt_padding_mask)
    trainLoss = loss_fn(trainPred.view(-1, vocabSize), trainTgt.view(-1))
    opt.zero_grad()
    trainLoss.backward()
    opt.step()
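For context, the masks are consumed inside model.forward roughly like this (a simplified sketch of the CNN-encoder / transformer-decoder wiring; the submodule names are placeholders, not the exact model code):

    def forward(self, src, tgt, tgt_mask, tgt_padding_mask):
        memory = self.cnn_encoder(src)                 # image features from the CNN backbone
        tgt_emb = self.pos_enc(self.embedding(tgt))    # embed + positionally encode target ids
        out = self.decoder(                            # nn.TransformerDecoder
            tgt_emb, memory,
            tgt_mask=tgt_mask,                         # causal mask from generate_tgt_mask
            tgt_key_padding_mask=tgt_padding_mask,     # True at [PAD] positions
        )
        return self.fc_out(out)                        # per-position logits over the vocabulary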
Inference code:
model.eval()
with torch.inference_mode():
    tgt_mask: torch.Tensor = generate_tgt_mask(maxSeqLen).cuda()
    for idx, (src, tgt, eq, fPath) in enumerate(valDataLoader):
        src, tgt = src.cuda(), tgt.cuda()
        inf_tgt = torch.ones((1, maxSeqLen), dtype=torch.int64).to(device)
        inf_tgt[:, 0] = tokenizer.vocab['[SOS]']
        inf_tgt[:, 1:] *= tokenizer.vocab['[PAD]']
        for i in range(1, maxSeqLen):
            inf_tgt_padding_mask = generate_padding_mask(inf_tgt, padTokenIdx=0).to(device)
            pred: torch.Tensor = model(src, tgt, tgt_mask, inf_tgt_padding_mask)
            predToken = torch.argmax(pred.softmax(dim=-1), dim=-1)
            inf_tgt[0, i] = predToken[0, i]
            if predToken[0, i].cpu() == tokenizer.vocab['[EOS]']:
                break
        break
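(To inspect the output I just map the generated ids back to LaTeX tokens by inverting tokenizer.vocab and compare against eq; sketch below, the exact decoding helper is not the issue.)

    # illustrative decoding, assuming tokenizer.vocab maps token string -> id as used above
    id2tok = {v: k for k, v in tokenizer.vocab.items()}
    specials = {tokenizer.vocab[t] for t in ('[SOS]', '[EOS]', '[PAD]')}
    predLatex = ' '.join(id2tok[int(t)] for t in inf_tgt[0] if int(t) not in specials)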
I have tried setting the target mask and the target padding mask to None, and the autoregressive decoding still produces the wrong output. I have also checked the mask generation functions and they seem to be fine. However, when I pass the actual target sequence to the decoder during inference, the model works fine and predicts the sequence correctly, so I suspect it is somehow cheating by looking at the ground-truth tokens instead of its own predictions.
Why is this happening, and what should I do?