https://pytorch.org/tutorials/beginner/translation_transformer.html
The training loop in the above link:
from torch.utils.data import DataLoader

# Multi30k, collate_fn, create_mask, loss_fn, and the constants
# (SRC_LANGUAGE, TGT_LANGUAGE, BATCH_SIZE, DEVICE) are defined earlier in the tutorial.
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))
I don't understand what exactly the parts tgt_input = tgt[:-1, :] and tgt_out = tgt[1:, :] achieve. It seems to take tgt as the input while excluding the last batch, and then to use tgt again for the labels while excluding the first batch? If that is what it does, then what purpose does it serve?
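To make the question concrete, here is a minimal, self-contained sketch of what those two slices produce. The [seq_len, batch_size] layout (which is what I believe the tutorial's collate_fn emits) and the token values are my own assumptions for illustration:

    import torch

    # Toy target batch laid out [seq_len, batch_size] like in the tutorial:
    # 5 time steps, 2 sequences; token ids are made up (2 = <bos>, 3 = <eos>).
    tgt = torch.tensor([[ 2,  2],
                        [10, 20],
                        [11, 21],
                        [12, 22],
                        [ 3,  3]])

    tgt_input = tgt[:-1, :]  # everything except the last row
    tgt_out   = tgt[1:, :]   # everything except the first row

    print(tgt_input)
    # tensor([[ 2,  2],
    #         [10, 20],
    #         [11, 21],
    #         [12, 22]])

    print(tgt_out)
    # tensor([[10, 20],
    #         [11, 21],
    #         [12, 22],
    #         [ 3,  3]])

So tgt_out looks like tgt_input shifted by one position along dim 0. Why does the model receive one version as input while the loss is computed against the other?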