A different approach for setting the grad mode in the transfer learning tutorial

In the transfer learning tutorial, the function for training the model (train_model) wraps the forward pass and loss computation of every mini-batch in with torch.set_grad_enabled(phase == 'train'). I am wondering whether an alternative approach would be to call _ = torch.set_grad_enabled(phase == 'train') once, before the data loader loop. That way the grad mode is set once per phase (twice per epoch) instead of re-entering the context manager on every iteration. I don't expect this to have a noticeable effect on performance; my question is whether it is still valid and produces the same results as the approach used in the documentation.

For reference, here is what the train_model function would look like (I commented out the docs approach):

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                ########## Approach I want to use:
                # Set the grad mode based on the phase (done once per phase)
                _ = torch.set_grad_enabled(phase == 'train')
                ########## -------------------------

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    ########## Approach in documentation:
                    # forward
                    # track history if only in train
                    #with torch.set_grad_enabled(phase == 'train'):
                    #    outputs = model(inputs)
                    #    _, preds = torch.max(outputs, 1)
                    #    loss = criterion(outputs, labels)
                    ########## ------------------------------------

                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # ... rest of the function ...

Thanks!

Yes, enabling or disabling the gradient calculation globally should also work, as long as you don't depend on the default behavior (gradient calculation enabled) somewhere else in the code, which doesn't seem to be the case here.
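One thing to keep in mind: used as a plain function call, torch.set_grad_enabled does not restore the previous grad mode the way the context-manager form does, so the mode set in the last phase ('val' here) stays active after the loop finishes. A minimal standalone sketch of the difference (not tied to the tutorial code):

import torch

# Context-manager form: the previous grad mode is restored on exit.
with torch.set_grad_enabled(False):
    print(torch.is_grad_enabled())  # False inside the block
print(torch.is_grad_enabled())      # True again after the block

# Plain function call: the mode stays in effect until it is changed again.
_ = torch.set_grad_enabled(False)
x = torch.randn(2, 2, requires_grad=True)
y = x * 2
print(y.requires_grad)              # False, autograd is disabled globally
_ = torch.set_grad_enabled(True)    # re-enable gradients for later code

So if anything after train_model relies on gradients being enabled (e.g. further fine-tuning), call torch.set_grad_enabled(True) once afterwards, or just keep the context manager from the tutorial.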
