Understanding training

Hi everyone! I have a quick question regarding the tutorial posted on pytorch for computer vision training, specifically at this link: Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 1.10.0+cu102 documentation. In the train_model function, I don't exactly understand why the output is not detached during training (see snippet below). Wouldn't the outputs from the model still be attached to the computational graph when they are used to get preds below? Another question about the code: I don't believe the function moves the output to the CPU during the validation phase when computing the predictions, which is something I've typically seen done. Is that not standard practice? Am I missing something here, or is that all correct? Thanks for the help!

with torch.set_grad_enabled(phase == 'train'):
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    loss = criterion(outputs, labels)

    # backward + optimize only if in training phase
    if phase == 'train':
        loss.backward()
        optimizer.step()

If you detach the output during training, no gradients would be calculated for the trainable model parameters and the model would thus not train at all.
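
As a minimal sketch of why (toy model and random data made up for illustration, not the tutorial's setup): once the output is detached, the loss no longer has a path back to the parameters, so backward fails and no gradients are produced.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)           # toy stand-in for the tutorial's model
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 10)
labels = torch.randint(0, 2, (4,))

outputs = model(inputs).detach()   # detaching cuts the link to the model parameters
loss = criterion(outputs, labels)
print(loss.requires_grad)          # False -> nothing to backpropagate through
# loss.backward() would raise:
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn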

If you are concerned about detaching the output during validation: you wouldn’t need to detach the tensor manually as the validation phase disables the gradient calculation via:

with torch.set_grad_enabled(phase == 'train'):
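
A quick sketch to verify this (toy module and data, just for illustration): with grad mode disabled, the forward pass records no history, so the output is already detached from any graph.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
inputs = torch.randn(4, 10)

phase = 'val'
with torch.set_grad_enabled(phase == 'train'):
    outputs = model(inputs)

print(outputs.requires_grad)  # False
print(outputs.grad_fn)        # None -> no computation graph was recorded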

No, I don’t believe it’s standard practice. Where have you seen it and would you know what the reason would be?

Thanks for the reply! I think the confusion comes from examples I've looked at, like the code snippet below. Note the line out = torch.argmax(outputs.detach(), dim=1) used during training. I was confused why this was used here but not in the training loop I referenced in the pytorch tutorial (I wasn't sure if I was misunderstanding something). Similarly, note at the bottom of this training block the line out = model(inputs.cuda()).cpu(), which is done during validation. Any thoughts on why these steps are required in this training block as opposed to the one on pytorch's site?

for epoch in range(5):
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(), dim=1)
        assert out.shape == labels.shape
        running_acc += (labels == out).sum().item()
    print(f"Train loss {epoch+1}: {running_loss/len(trainset)}, Train Acc: {running_acc*100/len(trainset)}%")

    correct = 0
    model.train(False)
    with torch.no_grad():
        for inputs, labels in valloader:
            out = model(inputs.cuda()).cpu()
            out = torch.argmax(out, dim=1)
            acc = (out == labels).sum().item()
            correct += acc
    print(f"Val accuracy: {correct*100/len(valset)}%")
    if correct > best_val_acc:
        best_val_acc = correct
        best_val_model = deepcopy(model.state_dict())
    lr_scheduler.step()

torch.argmax(outputs.detach(), dim=1) is used to get the predicted class indices in a multi-class classification use case. Detaching the tensor is needed if you are storing the calculated accuracy, e.g. in a list, or are accumulating it, as otherwise you would store the entire computation graph with it, which would increase the memory usage (and could eventually yield an out-of-memory error). Alternatively, you could also wrap the accuracy calculation in a with torch.no_grad() block, but it depends on your coding style I guess.
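
A rough sketch of the difference (toy model, made-up names): accumulating a tensor that still carries a grad_fn keeps every iteration's graph alive, while accumulating a detached value (e.g. via item()) does not.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

running_attached = 0.0   # will become a tensor that references the graphs
running_detached = 0.0   # stays a plain Python float

for _ in range(3):
    outputs = model(torch.randn(4, 10))
    loss = criterion(outputs, torch.randint(0, 2, (4,)))

    running_attached = running_attached + loss         # keeps each iteration's graph alive
    running_detached = running_detached + loss.item()  # detached scalar, graphs can be freed

print(running_attached.grad_fn)   # <AddBackward0 ...> -> history is still referenced
print(type(running_detached))     # <class 'float'>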

This isn’t needed since the acc value is created via item(), which will create a Python literal on the CPU.
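
For example (a tiny sketch; the .cuda() call is guarded so it only runs if a GPU is available):

import torch

labels = torch.tensor([0, 1, 1, 2])
preds = torch.tensor([0, 1, 0, 2])
if torch.cuda.is_available():
    labels, preds = labels.cuda(), preds.cuda()

acc = (preds == labels).sum().item()  # item() returns a plain Python int on the host
print(acc, type(acc))                 # 3 <class 'int'>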

Awesome, thanks so much for all of your time and help, it's clarified quite a bit! One follow up to make sure I've got this - is the use of argmax here distinct from the use of torch.max in the PyTorch tutorial (i.e. in the line _, preds = torch.max(outputs, 1); see the entire training epoch below)? I'm wondering because the PyTorch version uses the indices returned by torch.max to calculate accuracies but never detaches the output of torch.max. Wouldn't that similarly store the computation graph, or does torch.max behave differently? And sorry, but another follow up - from what I understand, the PyTorch tutorial's use of torch.max should extend to the multi-class classification case as well, correct? Since torch.max returns a tuple of the max values and their indices, and I believe the training loop only uses the indices themselves, I think it should work with multi-class labels, but definitely let me know if I'm mistaken on that. I was hoping to model my own training loop (which requires at least 3 classification categories) on the PyTorch approach of using torch.max. Appreciate your help and sorry for the follow ups - just making sure my understanding is accurate.

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

Yes, you are exactly right. torch.max would return the values with their corresponding indices, while torch.argmax only returns the indices which correspond to the max values.
As you can see in the tutorial, the values are not used in _, preds = torch.max(outputs, 1), so it would be equivalent to preds = torch.argmax(outputs, 1).
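
To make the equivalence concrete (random 3-class logits, just for illustration), the indices from torch.max match torch.argmax, and the accuracy calculation works for any number of classes:

import torch

outputs = torch.randn(8, 3)           # logits for 8 samples, 3 classes
labels = torch.randint(0, 3, (8,))

_, preds_max = torch.max(outputs, 1)
preds_argmax = torch.argmax(outputs, 1)

print(torch.equal(preds_max, preds_argmax))  # True
print((preds_max == labels).sum().item())    # number of correct predictions in the batch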

That's also a great question! You are generally right and should be careful about detaching the tensors. However, since the returned indices from torch.argmax are not differentiable, the computation graph won't be stored. If you are using torch.max, note that the values will be attached to the computation graph, while the indices won't be. You can check it via:

output = torch.randn(10, 10, requires_grad=True)

val, idx = torch.max(output, dim=1)

# a valid grad_fn shows that this tensor is attached to the computation graph
print(val.grad_fn)
> <MaxBackward0 object at 0x7fe084afbbb0> 

print(idx.grad_fn)
> None

idx = torch.argmax(output, dim=1)
print(idx.grad_fn)
> None