My network's weights get updated despite using torch.no_grad()

Hi everyone :slight_smile:

I want to implement an early stopping mechanism for my DNN, so I have to calculate the validation loss for each epoch.
I use @torch.no_grad() during the validation loss calculation to avoid any gradient computation that would alter the training.

Here is my training loop:

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

import dnn         # my model definition
import validation  # my validation helpers (see below)

# EarlyStopping and train_epoch are defined elsewhere in my project.


def main():
    training_dataset = torch.load("dataset/training_dataset.pth")

    net = dnn.Net()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    early_stopping = EarlyStopping(tolerance=5, min_delta=10)
    validation_dataset = torch.load("dataset/validation_dataset.pth")

    n_epochs = 150
    for epoch in tqdm.tqdm(range(n_epochs)):
        epoch_train_loss = train_epoch(training_dataset, net, criterion, optimizer)
        # Validation pass: wrapped in no_grad so no gradients are tracked.
        with torch.no_grad():
            epoch_validate_loss, _ = validation.validate(
                validation_dataset, net, criterion
            )
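
The early-stopping check itself is omitted from the loop above for brevity; the idea is a simple gap-based class called once per epoch, roughly along these lines (simplified sketch, details may differ from my actual code):

class EarlyStopping:
    """Stop training once the validation loss stops improving relative to the training loss."""

    def __init__(self, tolerance: int = 5, min_delta: float = 10):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss: float, validation_loss: float) -> None:
        # Count the epochs where the validation loss exceeds the training loss by min_delta.
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True


# At the end of each epoch in main():
#     early_stopping(epoch_train_loss, epoch_validate_loss)
#     if early_stopping.early_stop:
#         break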

Here is the validation function:

@torch.no_grad()
def validate(
    validation_dataset: torch.utils.data.Dataset, net: nn.Module, criterion: nn.Module
):
    net.eval()
    validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1)

    correct = 0
    loss = 0.0

    for inputs, target in validation_loader:
        output = net(inputs)
        # Accumulate the loss over all batches instead of keeping only the last one.
        loss += criterion(output, target).item()

        _, pred = output.max(1)
        correct += (pred == target).sum().item()

    loss /= len(validation_loader)
    accuracy = correct / len(validation_dataset) * 100.0
    return loss, accuracy

NB: I know that using both the @torch.no_grad() decorator and the with torch.no_grad() context manager is overkill.

I run the training for 150 epochs.
After training, I calculate the accuracy on the validation set.

Here is what I obtain when:

  • NOT calculating validation loss for each epoch: 91.67%
  • Calculating validation loss for each epoch: 97.22%

It then seems that calculating the validation loss for each epoch definitely alters training by leaking validation data.

Can you tell me what I am doing wrong here?

Not necessarily, as you could also be observing noise created by additional calls to the pseudorandom number generator. You could check how sensitive your training is to the seed.
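
E.g. a small helper along these lines (the name set_seed is just for illustration) lets you rerun the training with different, but fixed, seeds:

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Fix the common sources of randomness so that two runs start from the same RNG state.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the RNGs of all CUDA devices

Calling it at the top of main() with a handful of different seeds shows how much the final accuracy moves between runs.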

Oh, I did not think of that! So can I safely conclude that there is nothing wrong here?

You could run additional tests using different seeds to estimate the mean and stddev of e.g. the accuracy or loss of your current training routine, and compare that spread to the currently observed difference. You could also double-check that the weights are never actually updated by comparing (some) parameters before and after the validation run, just to make sure.
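
E.g. something like this sketch, using the net, criterion, validation_dataset, and validate from your post, would verify that no parameter is modified by the validation pass:

# Snapshot every parameter before the validation pass ...
params_before = {name: param.detach().clone() for name, param in net.named_parameters()}

with torch.no_grad():
    validate(validation_dataset, net, criterion)

# ... and verify that none of them changed afterwards.
for name, param in net.named_parameters():
    assert torch.equal(params_before[name], param), f"{name} was modified during validation!"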

I’ve explicitly checked the model’s weights, and they are not modified after the validation run. Thank you for your time and your answer!