Has anyone tried training a transformer without teacher forcing? I was thinking of doing it, but I couldn't figure out how to implement it batched.
Here is my code; it works with batch size = 1.
with torch.no_grad():
    for batch in test_loader:
        src, trg = batch
        imgs.append(src.flatten(0, 1))  # 'imgs' is defined earlier in the script
        src, trg = src.to(device), trg.to(device)
        # Encode the source once; decoding then runs autoregressively.
        memory = get_memory(model, src.float())
        out_indexes = [tokenizer.chars.index('SOS')]
        for i in range(max_text_length):
            # Causal mask over the tokens generated so far.
            mask = model.generate_square_subsequent_mask(i + 1).to(device)
            # Shape (T, 1): sequence-first with a batch dimension of 1.
            trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
            output = model.vocab(
                model.transformer_decoder(
                    model.query_pos(model.decoder(trg_tensor)), memory, tgt_mask=mask))
            # Greedily take the most likely token at the last position.
            out_token = output.argmax(2)[-1].item()
            out_indexes.append(out_token)
            if out_token == tokenizer.chars.index('EOS'):
                break
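
For reference, here is a rough sketch of what I think the batched version should look like. It is untested and assumes the same pieces as above (get_memory, tokenizer, the model submodules, device, max_text_length, and a memory tensor shaped (S, B, E)). The idea is to keep a boolean "finished" flag per sequence and pad completed sequences with EOS so the whole batch can keep stepping together:

import torch

def greedy_decode_batch(model, src, tokenizer, max_text_length, device):
    sos = tokenizer.chars.index('SOS')
    eos = tokenizer.chars.index('EOS')
    batch_size = src.size(0)

    with torch.no_grad():
        memory = get_memory(model, src.float())  # encoder output, assumed (S, B, E)
        # Sequence-first layout (T, B), matching the batch-size-1 loop above.
        ys = torch.full((1, batch_size), sos, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_text_length):
            mask = model.generate_square_subsequent_mask(ys.size(0)).to(device)
            output = model.vocab(
                model.transformer_decoder(
                    model.query_pos(model.decoder(ys)), memory, tgt_mask=mask))
            # Greedy pick at the last position, one token per sequence: (B,)
            next_token = output[-1].argmax(-1)
            # Sequences that already emitted EOS keep getting padded with EOS.
            next_token = torch.where(
                finished, torch.full_like(next_token, eos), next_token)
            ys = torch.cat([ys, next_token.unsqueeze(0)], dim=0)
            finished |= next_token == eos
            if finished.all():  # stop once every sequence in the batch is done
                break
    return ys  # (T, B); each column ends at its first EOS

The per-sequence early break from the batch-size-1 loop becomes finished.all() here, since different sequences hit EOS at different steps. Does this look right, or is there a standard way to do this?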