Why do we need if phase == 'train': inside of with torch.set_grad_enabled(phase == 'train'):

Hello guys,
using the PyTorch tutorial, I wrote the following code for training and validation:

for epoch in range(num_epochs):  
    # Each epoch consists of a training-phase and a validation-phase
    for phase in ["train", "val"]:
        if phase == 'train':
            # We set our model to Training-mode and choose the train-dataloader
            model.train()
            dataloader = train_dataloader
        else:
            # We set our model to Validation-mode and choose the validation-dataloader
            model.eval()
            dataloader = val_dataloader

        for iter_num, (x_for_phase, y_for_phase) in enumerate(dataloader):

            """
            Take the first Batch "x_for_phase" and input this to the model. The model computes all
            output-values "prediction" which we compare with our labels for the current batch "y_for_phase".
            """
            prediction = model(x_for_phase)
            
            # Apply the Loss-Function
            loss = loss_func(prediction, y_for_phase)

            # Delete the gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                if phase == 'train':     # Why do we need this, isn't it contained in the line above?
                    # Compute the gradients
                    loss.backward()

                    # Update the weights
                    optimizer.step()

At first I implemented it without the if phase == 'train': inside the with torch.set_grad_enabled(phase == 'train'): block, because I assumed that with gradients disabled, optimizer.step() wouldn’t do anything. However, when I remove the training phase from the inner for loop and only run the validation phase, the model weights still get updated and the model keeps improving. So I’m wondering why we have to both disable the gradients and check the phase. Is there any application where we want to disable the gradients but update the weights anyway?
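Here is a minimal, self-contained sketch of what surprised me (the tiny linear model and random data are just placeholders): the forward pass runs while grad mode is still enabled, so backward() and optimizer.step() inside set_grad_enabled(False) still change the weights:

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)
y = torch.randn(8, 1)

prediction = model(x)                                 # graph is built here, grad mode is on
loss = torch.nn.functional.mse_loss(prediction, y)
optimizer.zero_grad()

weights_before = model.weight.detach().clone()
with torch.set_grad_enabled(False):
    loss.backward()                                   # uses the graph that was built above
    optimizer.step()                                  # only reads the already stored .grad tensors

print((model.weight.detach() - weights_before).abs().sum())   # non-zero, the weights changed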


EDIT:
Would the following also be correct?

for epoch in range(num_epochs):   
    # Each epoch consists of a training-phase and a validation-phase
    for phase in ["train", "val"]:
        if phase == 'train':
            # We set our model to Training-mode and choose the train-dataloader
            model.train()
            dataloader = train_dataloader
        else:
            # We set our model to Validation-mode and choose the validation-dataloader
            model.eval()
            dataloader = val_dataloader

        for iter_num, (x_for_phase, y_for_phase) in enumerate(dataloader):

            """
            Take the first Batch "x_for_phase" and input this to the model. The model computes all
            output-values "prediction" which we compare with our labels for the current batch "y_for_phase".
            """
            prediction = model(x_for_phase)
            
            # Apply the Loss-Function
            loss = loss_func(prediction, y_for_phase)

            # Delete the gradients
            optimizer.zero_grad()

            
            if phase == 'train':
                # Compute the gradients
                loss.backward()

                # Update the weights
                optimizer.step()

So, I removed the line with torch.set_grad_enabled(phase == 'train'):, because gradients are enabled by default, and thanks to the if phase == 'train': neither loss.backward() nor optimizer.step() is called during the validation phase. The only downside of this version seems to be that the computational graph is still built for every iteration of the validation phase, which may lead to worse runtime. Is that right?
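To check this, here is a small sketch (reusing the names model and x_for_phase from my loop above) showing that the validation output still carries a grad_fn when grad mode stays enabled, i.e. the graph really is recorded:

model.eval()

out = model(x_for_phase)
print(out.grad_fn is not None)   # True  -> autograd recorded the forward pass

with torch.no_grad():
    out = model(x_for_phase)
print(out.grad_fn is not None)   # False -> no graph is built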

Thanks in advance,
Matthias


Any answer for that?

It depends on the optimizer: optimizers with running internal states (such as Adam) can update the parameters even with a zero gradient.
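For example, here is a quick sketch with a single parameter (the values are only for illustration): after one real step, Adam's running estimates are non-zero, so a further step with an exactly zero gradient still moves the parameter:

import torch

p = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.Adam([p], lr=0.1)

# One real step to populate Adam's internal running estimates
p.grad = torch.tensor([1.0])
optimizer.step()

# Step again with an explicitly zero gradient
p.grad = torch.tensor([0.0])
before = p.detach().clone()
optimizer.step()

print(p.detach() - before)   # non-zero: the internal state still updates the parameter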

Yes, you would waste compute resources during the validation phase and increase the memory usage.
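One way to avoid that would be to move the forward pass into the grad context, roughly as in the transfer learning tutorial (a sketch reusing your names):

for x_for_phase, y_for_phase in dataloader:
    optimizer.zero_grad()

    with torch.set_grad_enabled(phase == 'train'):
        # The graph is only built during training; validation runs without autograd overhead
        prediction = model(x_for_phase)
        loss = loss_func(prediction, y_for_phase)

        if phase == 'train':
            loss.backward()
            optimizer.step()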