Where should I put train() and eval()?

Hi,
I've seen so many ways of writing the training/validation part of a CNN; here is mine:

model.train()
for e in range(epoch):
    train_sum_loss = 0.0
    validation_sum_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model.forward(inputs)
        batch_loss = loss(outputs.squeeze(), labels)
        batch_loss.backward()
        optimizer.step()
        train_sum_loss += batch_loss.item()

    model.eval()
    for inputs, labels in validation_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model.forward(inputs)
        batch_loss = loss(outputs.squeeze(), labels)
        validation_sum_loss += batch_loss.item()

The results seem good: the network converges and doesn't overfit.
But if I put model.train() just before the training for loop (like I do for the validation loop), the network overfits really fast.

Which of these two methods is the correct one? Which result can I trust?

Thanks for helping.

model.train()
for inputs, labels in train_loader:

The model.train() needs to go there. If you put it outside the epoch loop as in your snippet, the model will only be in training mode for the first epoch; after the model.eval() call at the end of that epoch, all subsequent epochs will run in evaluation mode.
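In other words, the structure should look roughly like this (a sketch reusing the model, optimizer, loss, and loaders from your post):

for e in range(epoch):
    model.train()  # switch back to training mode at the start of every epoch
    train_sum_loss = 0.0
    validation_sum_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss(outputs.squeeze(), labels)
        batch_loss.backward()
        optimizer.step()
        train_sum_loss += batch_loss.item()

    model.eval()  # evaluation mode (dropout off, batchnorm uses running stats)
    for inputs, labels in validation_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        batch_loss = loss(outputs.squeeze(), labels)
        validation_sum_loss += batch_loss.item()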

Ok thanks.

What about this instruction: torch.no_grad()?

Some people use it, but not everyone. When I read the documentation, it seems to me that this command does the same thing as model.eval(), doesn't it?

No. They aren’t the same. I’ve copied a reply from another user that explains the difference.

In essence, simply using torch.no_grad() will not affect layers such as dropout or batchnorm; they will still be active during inference.
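A quick sketch with a small throwaway model (not from the thread) that shows the difference:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5))
x = torch.randn(1, 10)

with torch.no_grad():
    # Dropout is still active: two forward passes usually give different outputs.
    print(torch.equal(net(x), net(x)))   # usually False

net.eval()
with torch.no_grad():
    # With model.eval(), dropout is disabled: the two passes match.
    print(torch.equal(net(x), net(x)))   # True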

OK, thanks for the details.

So I can use them together in the validation loop, something like this (reusing the variables from my first post):
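model.eval()
with torch.no_grad():  # no gradients needed for validation
    for inputs, labels in validation_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        batch_loss = loss(outputs.squeeze(), labels)
        validation_sum_loss += batch_loss.item()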

But I haven't found a way to do the opposite of no_grad() (to switch gradient tracking back on for training).

Edit:
Maybe I've just found it: torch.set_grad_enabled(True)
Is that the right instruction?

Yes. I was just going to post that.

@DSX

I apologize. I just double-checked the documentation: you do not need to use set_grad_enabled(). torch.no_grad() only sets requires_grad=False temporarily, inside the with block.

x = torch.randn(3, requires_grad=True)

print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

print((x ** 2).requires_grad)

Output:
True
True
False
True

Thanks for double-checking.
It will simplify my code a bit :+1: