        # Teacher-forced forward pass: encode the source, then decode the
        # target sequence with the source padding mask and a causal mask.
        encoder_output = self.encoder(encoder_input, padding_mask)
        decoder_output = self.decoder(decoder_input, encoder_output, padding_mask, future_mask)
        # Drop the prediction at the last position and the first (start) token
        # of the targets so that position t is scored against token t + 1.
        decoder_output = decoder_output[:, :-1, :]
        decoder_input = decoder_input[:, 1:]
        # Flatten logits to (batch * (seq_len - 1), vocab) and targets to
        # (batch * (seq_len - 1),) as expected by nn.CrossEntropyLoss.
        batch_loss = criterion(
            decoder_output.reshape(-1, decoder_output.size(-1)),
            decoder_input.reshape(-1).long(),
        )
        return batch_loss, torch.argmax(decoder_output, dim=-1)
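# NOTE: hypothetical helper, not part of the original code -- shown only to
# illustrate how the argmax predictions returned by the training step (shape
# (batch, seq_len - 1)) can be scored against the shifted targets, with
# padding positions (id 0) excluded just like ignore_index=0 in the loss.
def masked_token_accuracy(predictions, targets, pad_id=0):
    mask = targets != pad_id
    correct = (predictions == targets) & mask
    return correct.sum().float() / mask.sum().clamp(min=1).float()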
# Loss, optimizer and learning-rate schedule: padding (id 0) is ignored and
# label smoothing of 0.1 matches the original Transformer recipe; the warmup
# scheduler below drives the effective learning rate during training.
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
params = list(transformers.parameters())
optimizer = optim.Adam(params, lr=0.01, weight_decay=0.1, betas=(0.9, 0.98), eps=1e-09)
num_epochs = 1
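# NOTE: assumption -- LrSchedule is defined earlier in the article; the sketch
# below is a minimal Noam-style warmup schedule that matches how it is used
# here (constructor arguments, .step(), .optimizer and ._rate), i.e.
# lr = factor * hidden_dim**-0.5 * min(step**-0.5, step * warmup**-1.5)
# from "Attention Is All You Need". The real implementation may differ.
class LrSchedule:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self.optimizer = optimizer
        self._step = 0
        self._rate = 0.0

    def rate(self, step=None):
        # Linear warmup for the first `warmup` steps, then inverse-sqrt decay.
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** -0.5) * min(step ** -0.5, step * self.warmup ** -1.5)

    def step(self):
        # Update the learning rate for every parameter group, then step Adam.
        self._step += 1
        self._rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = self._rate
        self.optimizer.step()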
scheduler = LrSchedule(
    transformers.hidden_dim, factor=1, warmup=4000, optimizer=optimizer,
)
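# NOTE: assumption -- create_padding_mask and create_future_mask are defined
# earlier in the article; these sketches only document the shapes/semantics
# the training loop below relies on (pad id 0, boolean masks broadcastable
# over the attention scores). The real helpers may use other conventions.
def create_padding_mask(ids, pad_id=0):
    # True where a real token is present, False at padding positions;
    # shape (batch, 1, 1, src_len) so it broadcasts over attention heads.
    return (ids != pad_id).unsqueeze(1).unsqueeze(2)

def create_future_mask(seq_len):
    # Lower-triangular causal mask: position i may attend only to j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))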
for epoch in range(num_epochs):
    for batch in dataloader:
        ids_src_tensor, ids_trg_tensor, _ = batch
        ids_src_tensor = ids_src_tensor.to(device)
        ids_trg_tensor = ids_trg_tensor.to(device)

        # Source padding mask for the encoder / cross-attention, causal mask
        # for the decoder self-attention.
        padding_mask = create_padding_mask(ids_src_tensor).to(device)
        future_mask = create_future_mask(ids_trg_tensor.shape[1]).to(device)

        # Forward pass and loss via the model's own train method defined above
        # (note that it shadows nn.Module.train, so it can't be used to toggle
        # training mode).
        batch_loss, out = transformers.train(ids_src_tensor, ids_trg_tensor, padding_mask, future_mask, criterion)

        # Backprop, clip gradients *before* stepping the optimizer, then reset.
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(transformers.parameters(), max_norm=1.0)
        scheduler.step()
        scheduler.optimizer.zero_grad()

        print(scheduler._rate)
        print(torch.sum(out, dim=-1))
        print(batch_loss.item())