set_grad_enabled(True/False) for train/val predictions

In the tutorial (Transfer Learning for Computer Vision Tutorial, PyTorch Tutorials 1.11.0+cu102 documentation), why do we use set_grad_enabled when we make predictions in the train phase? When phase == 'val', gradient computation is disabled. My question is: does the statement "_, preds = torch.max(outputs, 1)" produce different results, mathematically or logically, depending on this setting? Why do I need to enable or disable grad when predicting in the train/val phases?

# zero the parameter gradients
optimizer.zero_grad()

# track autograd history only in the training phase
with torch.set_grad_enabled(phase == 'train'):
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    loss = criterion(outputs, labels)

    # backward + optimize only if in training phase
    if phase == 'train':
        loss.backward()
        optimizer.step()

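You would disable gradient computation during the validation phase to save memory: autograd then does not store the intermediate forward activations, which would only be needed to compute the gradients. The forward results themselves (outputs, and therefore preds) are the same in both modes.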
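To check that the forward results do not depend on the grad mode, here is a minimal sketch. It assumes a tiny hypothetical nn.Linear classifier and a random batch in place of the tutorial's ResNet and data loader, runs the same forward pass with gradients enabled and disabled, and compares the predictions and the autograd bookkeeping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the tutorial's model and batch
model = nn.Linear(8, 3)
inputs = torch.randn(4, 8)

def forward_and_predict(grad_enabled):
    # Same pattern as the tutorial: the flag only toggles graph recording
    with torch.set_grad_enabled(grad_enabled):
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
    return outputs, preds

train_outputs, train_preds = forward_and_predict(True)   # like phase == 'train'
val_outputs, val_preds = forward_and_predict(False)      # like phase == 'val'

# The values are identical; only the autograd bookkeeping differs
print(torch.equal(train_preds, val_preds))               # True
print(torch.allclose(train_outputs, val_outputs))        # True
print(train_outputs.requires_grad, train_outputs.grad_fn is not None)  # True True
print(val_outputs.requires_grad, val_outputs.grad_fn)    # False None
```

With grad disabled, the outputs carry no grad_fn, so no graph (and none of the activations it would keep alive) is stored.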
A more aggressive optimization might be achieved via the torch.inference_mode() context manager, which additionally disables view tracking and version counter bumps.
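A rough sketch of how the two guards compare for a validation-style forward pass, again assuming a hypothetical small nn.Linear model and random inputs. Both give the same predictions; the difference is that outputs created under inference_mode are "inference tensors", which is why it is only suitable where no backward pass will follow.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small model and batch, used only for illustration
model = nn.Linear(8, 3)
inputs = torch.randn(4, 8)

with torch.no_grad():
    nograd_outputs = model(inputs)
    _, nograd_preds = torch.max(nograd_outputs, 1)

with torch.inference_mode():
    inf_outputs = model(inputs)
    _, inf_preds = torch.max(inf_outputs, 1)

# Both guards skip graph recording and give the same predictions
print(torch.equal(nograd_preds, inf_preds))                       # True

# Tensors created under inference_mode are flagged as inference tensors;
# autograd refuses to save them for backward later, so keep inference_mode
# to code paths (like the val phase) that never call backward()
print(nograd_outputs.is_inference(), inf_outputs.is_inference())  # False True
```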

I am not sure whether it makes sense to disable grad when computing the predictions in the train phase, to save memory. I mean: during training, disable grad only for the prediction step. Does that affect the training predictions?

Do you mean something like this?

outputs = model(inputs)
with torch.no_grad():
    _, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

I.e. wrapping the max operation only into the no_grad guard?
If so, then I would probably .detach() the output instead, as the code looks a bit cleaner to me, but the no_grad guard should also work.
As explained before, the preds output will not change, since disabling the gradient calculation only saves resources.
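A minimal sketch comparing the options for the training step, assuming a hypothetical nn.Linear model, CrossEntropyLoss and a random batch with made-up labels. It also shows why the extra saving here is presumably small: the indices returned by torch.max are an int64 tensor that never requires grad, so guarding the max mainly skips bookkeeping for the discarded max values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model, loss and batch standing in for the tutorial's
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 8)
labels = torch.tensor([0, 1, 2, 1])

outputs = model(inputs)  # grad enabled: the graph is recorded for the loss

# Option 1: guard only the max with no_grad
with torch.no_grad():
    _, preds_nograd = torch.max(outputs, 1)

# Option 2: take the max on a detached tensor (shares data, no history)
_, preds_detach = torch.max(outputs.detach(), 1)

# Option 3: no guard at all; the indices still do not require grad
_, preds_plain = torch.max(outputs, 1)

loss = criterion(outputs, labels)
loss.backward()  # still works: the loss graph was not affected by options 1-3

print(torch.equal(preds_nograd, preds_detach), torch.equal(preds_detach, preds_plain))  # True True
print(preds_plain.dtype, preds_plain.requires_grad)  # torch.int64 False
print(model.weight.grad is not None)                 # True
```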

Thanks, I mean the following logic. Since torch.no_grad() saves memory, I was wondering if it makes sense to save extra memory during training as well.

model.train()
outputs = model(inputs)  # forward pass in training mode
model.eval()
with torch.no_grad():
    # compute the predictions without recording autograd history
    _, preds = torch.max(outputs, 1)
model.train()
loss = criterion(outputs, labels)

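Your code snippet fits my example, so yes, you can detach() the tensor or use the no_grad guard. The model.eval()/model.train() calls are not needed, though: they only change how modules such as dropout and batchnorm behave in later forward passes, and outputs has already been computed at that point.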
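A minimal sketch of one training step along those lines. It assumes a hypothetical small model with a dropout layer (so the module mode actually matters in the forward pass), an SGD optimizer and a random batch with made-up labels. It checks that toggling eval()/train() after the forward pass leaves preds unchanged, and counts correct predictions in the spirit of the tutorial's running_corrects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model with dropout, so train/eval mode matters during forward
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(4, 8)
labels = torch.tensor([0, 1, 2, 1])

model.train()
optimizer.zero_grad()
outputs = model(inputs)  # dropout is applied here, in train mode

# Toggling the mode after the forward pass does not touch `outputs`
model.eval()
_, preds_toggled = torch.max(outputs.detach(), 1)
model.train()

# Same predictions without the eval()/train() round trip
_, preds = torch.max(outputs.detach(), 1)

loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

running_corrects = torch.sum(preds == labels).item()
print(torch.equal(preds, preds_toggled))  # True
print(model.training)                     # True
```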