PyTorch Transformer Tutorial - What is tgt_input = tgt[:-1, :] supposed to do?

https://pytorch.org/tutorials/beginner/translation_transformer.html

The training loop from the tutorial linked above:

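from torch.utils.data import DataLoader
from torchtext.datasets import Multi30k

# SRC_LANGUAGE, TGT_LANGUAGE, BATCH_SIZE, DEVICE, collate_fn, create_mask
# and loss_fn are defined earlier in the tutorial.
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        # src: [src_seq_len, batch_size], tgt: [tgt_seq_len, batch_size].
        # The tutorial's collate_fn pads along dim 0, so time is the first dimension.
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Decoder input: every time step except the last.
        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # The last argument is the memory key padding mask, which reuses src_padding_mask.
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        # Labels: every time step except the first, i.e. the next token at each position.
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))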

I don't understand what the lines tgt_input = tgt[:-1, :] and tgt_out = tgt[1:, :] are meant to achieve.

It seems to take tgt as input excluding the last batch, and then to use tgt again for the labels but excluding the first batch? If that is what it does, what purpose does it serve?

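After some digging, it turns out that those slices shift the target tokens by one position relative to each other. During training (teacher forcing), the decoder is fed the ground-truth target sequence offset by one step from the labels, and the causal tgt_mask produced by create_mask stops each position from attending to later tokens. As a result, at every position the decoder learns to predict the next token from the tokens before it.

decoder input: tgt_input = tgt[:-1, :] → drops the last time step, so for a target BOS y1 ... yn EOS the decoder receives BOS y1 ... yn.
labels: tgt_out = tgt[1:, :] → drops the first time step, so the labels are y1 ... yn EOS, and the label at each position is the token that follows the decoder input at that position.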
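A minimal sketch of the shift on a toy sequence-first batch (shape [seq_len, batch_size], as in the tutorial). The token ids, the BOS_IDX/EOS_IDX values and VOCAB_SIZE are made up for illustration, and random logits stand in for the model output. It prints which prefix each decoder position can see under a causal mask and which label it is trained to predict, then flattens logits and labels the way the tutorial's loss line does, so row k of one lines up with element k of the other.

```python
import torch
import torch.nn as nn

BOS_IDX, EOS_IDX = 2, 3  # hypothetical special-token ids for this toy example
VOCAB_SIZE = 30          # hypothetical vocabulary size

# Toy targets, shape [seq_len, batch_size] = [5, 2] (sequence-first).
# Column 0 is BOS 11 12 13 EOS, column 1 is BOS 21 22 23 EOS.
tgt = torch.tensor([
    [BOS_IDX, BOS_IDX],
    [11, 21],
    [12, 22],
    [13, 23],
    [EOS_IDX, EOS_IDX],
])

tgt_input = tgt[:-1, :]  # drop the last time step  -> [4, 2]: BOS 11 12 13
tgt_out = tgt[1:, :]     # drop the first time step -> [4, 2]: 11 12 13 EOS

# Causal mask: position t may attend to positions 0..t only (True = visible).
steps = tgt_input.shape[0]
visible = torch.tril(torch.ones(steps, steps, dtype=torch.bool))

for t in range(steps):
    prefix = tgt_input[:, 0][visible[t]].tolist()  # tokens of sequence 0 visible at t
    print(f"position {t}: sees {prefix} -> label {tgt_out[t, 0].item()}")
# position 0: sees [2] -> label 11
# position 1: sees [2, 11] -> label 12
# position 2: sees [2, 11, 12] -> label 13
# position 3: sees [2, 11, 12, 13] -> label 3

# Stand-in for the model output: one score vector per (position, batch item).
torch.manual_seed(0)
logits = torch.randn(steps, tgt.shape[1], VOCAB_SIZE)  # [4, 2, VOCAB_SIZE]

# Flattening both tensors in the same (position, batch) order keeps each
# prediction paired with its label, as in the tutorial's loss line.
flat_logits = logits.reshape(-1, logits.shape[-1])  # [8, VOCAB_SIZE]
flat_labels = tgt_out.reshape(-1)                   # [8]
loss = nn.CrossEntropyLoss()(flat_logits, flat_labels)
print(loss.item())
```

The tutorial's own loss_fn additionally sets ignore_index to its padding id so padded positions do not contribute; this toy batch has no padding, so a plain CrossEntropyLoss is used here.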

However, the tutorial doesn't make it obvious that the batch dimension is the second dimension of tgt (its tensors are shaped [seq_len, batch_size]), and that is what confused me. In my dataset, tgt has shape [batch_dim, seq_len], and after embedding a third dimension is added: [batch_dim, seq_len, embedding_dim]. So when I followed the tutorial's training function, it was shifting the batches, not the tokens, which made no sense.
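For batch-first tensors like these, the shift has to be applied along dim 1 instead. A small sketch with made-up token ids (2 and 3 standing in for BOS and EOS) comparing the tutorial's slicing with the batch-first version:

```python
import torch

# Hypothetical batch-first targets: shape [batch_size, seq_len] = [2, 5].
tgt = torch.tensor([
    [2, 11, 12, 13, 3],
    [2, 21, 22, 23, 3],
])

# Tutorial-style slicing assumes [seq_len, batch_size]. On batch-first data it
# cuts along the batch dimension, so each result keeps only one whole sequence.
wrong_input = tgt[:-1, :]  # [1, 5]: only the first sequence
wrong_out = tgt[1:, :]     # [1, 5]: only the second sequence

# Shifting along the sequence dimension (dim 1) gives the intended token shift.
tgt_input = tgt[:, :-1]    # [2, 4]: each sequence without its last token
tgt_out = tgt[:, 1:]       # [2, 4]: each sequence without its first token

print(wrong_input.shape, wrong_out.shape)  # torch.Size([1, 5]) torch.Size([1, 5])
print(tgt_input.shape, tgt_out.shape)      # torch.Size([2, 4]) torch.Size([2, 4])
```

Slicing the token ids before embedding is simplest; an already embedded [batch_dim, seq_len, embedding_dim] tensor would be sliced as emb[:, :-1, :] and emb[:, 1:, :]. The tutorial's create_mask also reads the sequence length from dim 0, so it would need the same adjustment, as would nn.Transformer unless it is constructed with batch_first=True.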

I hope this helps if someone else gets confused about it.
