Network forgets everything during learning (strange behavior of loss function)

Hi,

I am training a model using PyTorch and PyTorch Geometric, and I've noticed a weird behavior of the loss function:

It seems that during training the network periodically forgets everything it has learned. Moreover, the more samples there are in the dataset, the more often this kind of “reset” occurs (and when the dataset is large, the network doesn’t learn at all).

These are the loss function, optimizer, and learning-rate scheduler I am using:

criterion = nn.BCEWithLogitsLoss()

lr_start = 2e-4
l2_weight = 1e-10 # L2 regularization weight
optimizer = optim.Adam(net.parameters(), lr=lr_start, weight_decay=l2_weight)

# Decay the learning rate by a factor of 0.99 every 5 scheduler steps,
# i.e. by 0.99 ** (1/5) per step.
lr_decay = torch.tensor(0.99, dtype=torch.float64)
lr_decay_steps = torch.tensor(5, dtype=torch.float64)
gamma = torch.pow(lr_decay, torch.div(1, lr_decay_steps)).item()  # plain Python float for the scheduler

scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=-1)
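
Just as a sanity check of that gamma (plain Python, approximate values in the comments):

    gamma = 0.99 ** (1 / 5)    # per-step decay factor, ~0.998
    print(gamma ** 5)          # ~0.99: total decay after 5 scheduler steps
    print(2e-4 * gamma ** 5)   # ~1.98e-04: learning rate after 5 steps, starting from lr_start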

I am not an expert in neural networks or PyTorch, so I was wondering whether there is some mistake in the code above, or what else could be causing this issue…

If you could share the code for your training loop then I can try debugging it.

This is my training loop:

    for epoch in range(load_epoch, load_epoch+n_epochs):
        start = time.time()
        loss, accuracy = train_epoch(net,
                                     optimizer,
                                     criterion,
                                     device,
                                     train_dataloader,
                                     clip_val
                                     )
        end = time.time()

        wandb.log({'training loss':loss})
        wandb.log({'training accuracy':accuracy})

        if (epoch+1) % print_every == 0:
            logger.info(f"Training loss at epoch {epoch+1} : {loss}")
            logger.info(f"Training accuracy at epoch {epoch+1} : {accuracy}")
            logger.info(f"Runtime for one epoch : {end-start}")

        if (epoch+1) % test_every == 0:
            start = time.time()
            val_loss, accuracy = test(net, device, test_dataloader, criterion)
            end = time.time()
            wandb.log({'validation loss':val_loss})
            wandb.log({'validation accuracy':accuracy})

            logger.info(f"Validation loss at epoch {epoch+1} : {val_loss}")
            logger.info(f"Validation accuracy at epoch {epoch+1} : {accuracy}")
            logger.info(f"Runtime for testing : {end-start}")

        if (epoch+1) % save_every == 0:
            logger.info(f"Saving training at epoch {epoch+1}")
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'losses': losses,
                'val_losses': val_losses,
                'accuracies': accuracies,
                'val_accuracies': val_accuracies
                }, join(save_folder,"epoch_{:04d}.tar".format(epoch+1)))

where:

def train_epoch(net, optimizer, loss, device, dataloader, clip_val):
    # training step on all dataset
    net.train()

    running_loss = 0.0
    accuracy = 0.0
    for batch in dataloader:
        # zero the parameter gradients
        optimizer.zero_grad()

        batch.to(device)
        pred = net(batch)
        pred_binary = torch.sigmoid(pred.squeeze()) >= 0.5
        pred_correct = pred_binary == batch.y.type_as(pred)
        accuracy_temp = torch.sum(pred_correct)/len(pred_correct)

        accuracy += accuracy_temp.item()

        logit = loss(pred.squeeze(), batch.y.type_as(pred))
        logit.backward()

        if clip_val is not None:
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_val)
        optimizer.step()

        running_loss += logit.detach().item() 

    running_loss /= len(dataloader)
    accuracy /= len(dataloader)

    return running_loss, accuracy

def test(net, device, test_dataloader, val_criterion):
    net.eval()
    net = net.to(device)

    with torch.no_grad():
        val_loss = 0.0
        accuracy = 0.0

        for batch in test_dataloader:
            batch.to(device)
            pred = net(batch)
            pred_binary = torch.sigmoid(pred.squeeze()) >= 0.5
            pred_correct = pred_binary == batch.y.type_as(pred)
            accuracy_temp = torch.sum(pred_correct)/len(pred_correct)

            accuracy += accuracy_temp.item()

            val_logit = val_criterion(pred.squeeze(),
                                      batch.y.type_as(pred)).item()
            val_loss += val_logit

    val_loss /= len(test_dataloader)
    accuracy /= len(test_dataloader)

    return val_loss, accuracy

(I’ve tried with clip_val = 0.65 and clip_val = None)

I hope this is enough code, and I am happy to share more if it can be useful.

Thank you for the code.

Could you tell me the following?

  • What is the meaning of accuracy_temp? What do you intend this variable to contain?
  • Why does it make sense to accumulate accuracy_temp values across batches by adding them?
  • Why does it make sense to divide the summed-up accuracy_temp values by the number of things in dataloader?

For instance: suppose your net predicted exactly half of the targets correctly in every batch. What final accuracy value would computing it in this manner give you? Does that match the accuracy you would expect for a net that gets exactly half of the targets in each batch right?

I have the same questions about logit as well: why does it make sense to add up the per-batch values, and why does it make sense to divide this sum by the total number of items in the input?

Thanks a lot for taking the time to help me.

My idea was to use accuracy_temp to hold the average accuracy over each batch, then sum all the computed accuracy_temp values and divide by the number of batches (which I thought was equal to len(dataloader)) to get the overall accuracy.

I realize now that this is correct only if the number of elements in each batch is constant, which is not necessarily true. I will change the code to count the total number of correctly predicted samples across all batches, and divide by the total number of samples only after the for loop.
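
Something like the following sketch is what I have in mind (assuming, as in the code above, that batch.y contains one binary label per sample):

    n_correct = 0
    n_samples = 0
    for batch in dataloader:
        batch.to(device)
        pred = net(batch)
        pred_binary = torch.sigmoid(pred.squeeze()) >= 0.5
        pred_correct = pred_binary == batch.y.type_as(pred)
        n_correct += pred_correct.sum().item()   # correctly predicted samples in this batch
        n_samples += pred_correct.numel()        # batch size (the last batch may be smaller)
    accuracy = n_correct / n_samples             # fraction correct over the whole dataset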

Regarding your question, I think that if the network predicted half of the elements in every batch correctly, I would get an accuracy of 0.5, as expected, but maybe the code is not doing what I thought…

With respect to the loss function, I applied similar reasoning: I wanted to compute the loss for each element and average it over the number of elements in the dataset, and since I am working with batches, I thought of averaging the per-batch losses over the number of batches instead. What would be the common way of computing the loss value?
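
What I am considering for now is to weight each per-batch loss by the number of samples in that batch before averaging, something like this sketch (same assumptions as above, with criterion = nn.BCEWithLogitsLoss() and its default 'mean' reduction):

    total_loss = 0.0
    n_samples = 0
    for batch in dataloader:
        batch.to(device)
        pred = net(batch).squeeze()
        target = batch.y.type_as(pred)
        batch_loss = criterion(pred, target)              # mean loss over this batch
        total_loss += batch_loss.item() * target.numel()  # undo the mean: sum of per-sample losses
        n_samples += target.numel()
    epoch_loss = total_loss / n_samples                   # per-sample mean over the whole dataset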

I hope this answers your questions.

You can easily check this assumption by working out the numbers by hand for a simplified scenario: there are 1000 inputs in total, split into 2 equal-sized batches of 500 elements each, and in each batch the network predicts half of the values correctly. What accuracy will be reported in this case if you follow your scheme? Does this match what you expected it to be? Does it correctly reflect how good the network is?

Yes, this is what I did: in your example the expected accuracy is 0.5 and the accuracy in each batch is 0.5. Summing the accuracies computed in each batch and dividing by the number of batches gives (0.5 + 0.5) / 2 = 0.5. Is this different from what you would expect?

Regarding the loss function, am I right that you are suggesting there is something wrong with how the logit values are computed and accumulated?

My mistake: I assumed that len(dataloader) denotes the number of items in the underlying dataset. If len(dataloader) denotes the number of batches delivered by the dataloader, then this computation is correct, except in the case you pointed out, where the batches have different sizes.
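
To make that unequal-batch case concrete, here is a tiny numeric sketch with made-up numbers (two batches of 500 and 10 samples, with 250 and 10 correct predictions respectively):

    # Hypothetical batch statistics, only to illustrate the averaging issue.
    correct = [250, 10]   # correctly predicted samples per batch
    sizes = [500, 10]     # number of samples per batch

    per_batch_acc = [c / s for c, s in zip(correct, sizes)]        # [0.5, 1.0]
    mean_of_batch_means = sum(per_batch_acc) / len(per_batch_acc)  # 0.75
    overall_accuracy = sum(correct) / sum(sizes)                   # 260 / 510 ≈ 0.51

The two values agree only when every batch has the same size, which is why counting correct predictions and total samples separately, as you plan to do, is the safer approach.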