Has anyone tried training a transformer without teacher forcing? I was thinking of doing it but could not work out how to implement it for a batch.
Here is my code. It works with a batch size of 1.
for batch in test_loader:
    src, trg = batch
    src, trg = src.cuda(), trg.cuda()
    memory = get_memory(model, src.float())
    out_indexes = [tokenizer.chars.index('SOS'), ]
    for i in range(max_text_length):
        mask = model.generate_square_subsequent_mask(i+1).to(device)
        trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
        output = model.vocab(model.transformer_decoder(model.query_pos(model.decoder(trg_tensor)), memory, tgt_mask=mask))
        out_token = output.argmax(2)[-1].item()
        if out_token == tokenizer.chars.index('EOS'):
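One possible way to run the same greedy loop on a whole batch is to keep one token column per sample and a boolean `finished` flag, so that samples which have already produced EOS are only padded. This is a minimal sketch, not the original poster's model: the tiny decoder, the sizes, and the `SOS`/`EOS`/`PAD` ids are all hypothetical, positional encodings are left out for brevity, and the random `memory` tensor stands in for the encoder output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and special-token ids, chosen only for illustration.
VOCAB, D_MODEL, SOS, EOS, PAD = 20, 32, 1, 2, 0

embed = nn.Embedding(VOCAB, D_MODEL)
layer = nn.TransformerDecoderLayer(d_model=D_MODEL, nhead=4, dim_feedforward=64)
decoder = nn.TransformerDecoder(layer, num_layers=1)
to_vocab = nn.Linear(D_MODEL, VOCAB)


@torch.no_grad()
def greedy_decode_batch(memory, max_len):
    """Greedy decoding for a whole batch; memory has shape [S, N, D]."""
    n_batch = memory.size(1)
    device = memory.device
    # One SOS per sample; tokens has shape [T, N] (sequence-first layout).
    tokens = torch.full((1, n_batch), SOS, dtype=torch.long, device=device)
    finished = torch.zeros(n_batch, dtype=torch.bool, device=device)
    for _ in range(max_len):
        t = tokens.size(0)
        # Causal mask: -inf above the diagonal, 0 elsewhere.
        mask = torch.triu(torch.full((t, t), float("-inf"), device=device), diagonal=1)
        hidden = decoder(embed(tokens), memory, tgt_mask=mask)  # [T, N, D]
        next_token = to_vocab(hidden[-1]).argmax(dim=-1)  # [N]
        # Samples that already emitted EOS only receive padding.
        next_token = next_token.masked_fill(finished, PAD)
        tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=0)
        finished |= next_token == EOS
        if finished.all():
            break
    return tokens.transpose(0, 1)  # [N, T]


decoder.eval()  # disable dropout for decoding
memory = torch.randn(7, 3, D_MODEL)  # stand-in for the encoder output
result = greedy_decode_batch(memory, max_len=10)
```

The loop stops early only when every sample in the batch has finished, so the cost is set by the longest sequence. For training without teacher forcing you would drop `torch.no_grad()` and accumulate a loss on the logits at each step instead of only taking the argmax.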
I’m trying a transformer without teacher forcing for time series prediction and I came up with this snippet. In theory it works fine, but it’s extremely slow and it keeps crashing with CUDA out-of-memory errors as soon as I use a batch size greater than 3 or 4 samples.
# src shape is [n_batch, src_length, d_model] --> [N, S, D]
# tgt and gt (ground truth) shape is [n_batch, tgt_length, d_model] --> [N, T, D]
l = 0.
for src, tgt, gt in loader:
    src, tgt, gt = src.to(device), tgt.to(device), gt.to(device)
    # first tgt element is the last src one, followed by a zero-filled tensor
    tgt = torch.cat((src[:, -1:, :], torch.zeros(tgt.shape[0], tgt.shape[1] - 1, tgt.shape[2], device=device)), dim=1)
    out = torch.zeros_like(tgt, device=device)
    memory = model.encode(src)
    for j in range(gt.shape[1]):
        out = model.decode(tgt, memory)
        l += loss(out[:, j, :], gt[:, j, :])
        tgt = tgt.detach()
        # insert the step-j prediction at position j + 1 of tgt
        tgt = torch.cat((tgt[:, :j + 1, :], out[:, j:j+1, :], tgt[:, j + 1:gt.shape[1] - 1, :]), dim=1)
The CUDA issues were probably caused by torch.autograd.set_detect_anomaly(True), which I had enabled for debugging purposes. The code should run fine without it.
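If anomaly detection is still needed now and then, one option is to scope it to a single debugging pass with the `torch.autograd.detect_anomaly()` context manager rather than switching it on globally. A minimal sketch with a toy tensor in place of the real model:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)

# Debugging pass: anomaly detection is active only inside this block.
with torch.autograd.detect_anomaly():
    enabled_inside = torch.is_anomaly_enabled()
    loss = (x ** 2).sum()
    loss.backward()

# Code after the block runs with anomaly detection switched off again.
enabled_outside = torch.is_anomaly_enabled()
```

PyTorch prints a warning when the block is entered, as a reminder that this mode is meant for debugging only.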
Thank you for sharing. What does d_model mean? I tried your code with BART, but input_ids has shape [n_batch, seq_length] and there is no d_model dimension. Can you help me?
I’m not an expert with BART, but I suppose what you’re referring to are the indexes into the embedding table used by the model. Each of those seq_length items is an index into a vocabulary of N embeddings, each of which is a tensor of size d_model. Those embeddings are the actual input to your model. This may be a helpful read.
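To make the shapes concrete, here is a small sketch of that lookup using a plain `nn.Embedding`. It is not BART's actual embedding layer, and the sizes are hypothetical; it only shows how a `[n_batch, seq_length]` tensor of indexes becomes `[n_batch, seq_length, d_model]`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
vocab_size, d_model = 100, 16
n_batch, seq_length = 2, 5

embedding = nn.Embedding(vocab_size, d_model)

# Token indexes, shape [N, S]; this plays the role of input_ids.
input_ids = torch.randint(0, vocab_size, (n_batch, seq_length))

# Each index selects one row of the embedding table, giving shape [N, S, D].
embedded = embedding(input_ids)
```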