Training a transformer without teacher forcing

Has anyone tried training a transformer without teacher forcing? I was thinking of doing it, but I couldn't figure out how to implement it for batches.
Here is my code; it works with batch size = 1.

with torch.no_grad():
    # Batched greedy (autoregressive) decoding without teacher forcing.
    # The decoder is fed sequence-first token ids of shape [T, N]; the
    # batch-size-1 version built [T, 1] via unsqueeze(1), so the batch size N
    # is read from dim 1 of memory.
    # NOTE(review): assumes get_memory returns seq-first [S, N, D] and that
    # model.query_pos accepts [T, N, D] — confirm against the model code.
    sos = tokenizer.chars.index('SOS')
    eos = tokenizer.chars.index('EOS')
    for batch in test_loader:
        src, trg = batch
        imgs.append(src.flatten(0, 1))
        src, trg = src.to(device), trg.to(device)
        memory = get_memory(model, src.float())
        n_batch = memory.shape[1]
        # [T, N] token ids decoded so far; every sequence starts with SOS.
        trg_tensor = torch.full((1, n_batch), sos, dtype=torch.long, device=device)
        # finished[k] becomes True once sequence k has emitted EOS.
        finished = torch.zeros(n_batch, dtype=torch.bool, device=device)
        for i in range(max_text_length):
            mask = model.generate_square_subsequent_mask(i + 1).to(device)
            output = model.vocab(model.transformer_decoder(model.query_pos(model.decoder(trg_tensor)), memory, tgt_mask=mask))
            # Prediction at the last position for every sequence: shape [N].
            next_tok = output.argmax(2)[-1]
            # Sequences that already produced EOS are padded with EOS so the
            # batch can keep decoding together.
            next_tok = next_tok.masked_fill(finished, eos)
            trg_tensor = torch.cat((trg_tensor, next_tok.unsqueeze(0)), dim=0)
            finished |= next_tok == eos
            # Stop once every sequence has emitted EOS; for a batch of one
            # this is the same stopping point as the per-token check.
            if bool(finished.all()):
                break
        # One list of token ids per sample, each starting with SOS.
        out_indexes = trg_tensor.t().tolist()

I’m trying to train a transformer without teacher forcing for time-series prediction, and I came up with this snippet. In theory it works, but it’s tremendously slow, and it keeps crashing because CUDA runs out of memory as soon as I use a batch size greater than 3 or 4 samples.

# src shape is [n_batch, src_length, d_model] --> [N, S, D]
# tgt and gt (groundtruth) shape is [n_batch, tgt_length, d_model] --> [N, T, D]
#
# Free-running training (no teacher forcing): the decoder input at step j+1
# is the model's own prediction for step j, not the ground truth.

for src, tgt, gt in loader:
    src, tgt, gt = src.to(device), tgt.to(device), gt.to(device)
    optimizer.zero_grad()

    # first tgt element is the last src one, followed by a zero-filled tensor
    # whose slots are overwritten with predictions as decoding proceeds
    tgt = torch.cat((src[:, -1:, :], torch.zeros(tgt.shape[0], tgt.shape[1] - 1, tgt.shape[2], device=device)), dim=1)

    memory = model.encode(src)
    # Reset the loss per batch. Initialising it once outside the loop makes
    # the second backward() traverse the previous batch's already-freed graph
    # and keeps every batch's graph alive in memory.
    l = 0.
    n_steps = gt.shape[1]
    for j in range(n_steps):
        # NOTE(review): the full sequence is re-decoded at every step and each
        # step's graph is retained for backward -> O(T^2) compute and memory,
        # which is what makes this slow / OOM for larger batches. If
        # model.decode applies a causal mask, decoding tgt[:, :j + 1] yields
        # the same out[:, j] far more cheaply — confirm before switching.
        out = model.decode(tgt, memory)
        l += loss(out[:, j, :], gt[:, j, :])
        if j + 1 < n_steps:
            # Write prediction j into input slot j+1. Earlier inputs are
            # detached, so gradients flow back through one feedback step
            # only; slots after j+1 are still zero placeholders.
            tgt = torch.cat((tgt[:, :j + 1, :].detach(), out[:, j:j + 1, :], tgt[:, j + 2:, :].detach()), dim=1)
    l.backward()
    # Apply the gradients; without this the parameters never change.
    optimizer.step()
1 Like