Transformer Training

I am training a transformer model I wrote from scratch for machine translation, and debugging with a very small data set (1000 sentences for training, 200 for dev). With 20 epochs and batch size of 64, the loss of the training and dev set follow the exact same pattern. Seems kind of strange. Anyone know what’s going on here? (attached training loop below as well).

mask = torch.tril(torch.ones((MAX_LENGTH, MAX_LENGTH)))
# optimization loop 
loss_fn = torch.nn.CrossEntropyLoss() 
train_losses = []
val_losses = []
for epoch in range(1,EPOCHS+1):

    # train loop 
    for i, (src,trg) in enumerate(train_data):

        # place tensors to device 
        src = torch.Tensor(src).to(DEVICE).long()
        trg = torch.Tensor(src).to(DEVICE).long()

        # forward pass 
        out = model(src,trg, mask)

        # compute loss 
        train_loss = loss_fn(out.view(-1,tgt_vocab), trg.view(-1))
        # backprop 

        # update weights 
    val_loss = 0
    num_batches = len(dev_data)

    for i, (src, trg) in enumerate(dev_data):

        # place tensors on device 
        src = torch.Tensor(src).to(DEVICE).long()
        trg = torch.Tensor(src).to(DEVICE).long()

        # forward pass 
        out = model(src, trg, mask)
        # compute loss 
        loss_val = loss_fn(out.view(-1,tgt_vocab),trg.view(-1))
        val_loss += loss_val.item()

    val_loss /= num_batches
    print(f'Epoch[{epoch}/{EPOCHS}] train_loss: {train_loss.item()} val_loss: {val_loss}')

I think it’s a good sign to see a similar behavior between the training and validation losses. Could you describe your concern a bit more and what you would expect to see?

I would expect to see some overfitting at some point, with the validation loss starting to increase, but its pattern matches the train loss almost exactly, it seemed unusual.

I would expect to see the same, but note that your training loss is also stuck when you stopped the training and your model does not “perfectly overfit” the training set. Generally, I would expect to see a divergence in the validation loss if your model is able to reduce the training loss towards ~zero.

1 Like