Where should I use with torch.no_grad()?

Hello everyone.
I’m trying to write a single function that handles both training and validation, since about 90% of the code is shared between them. However, I’m not sure where `with torch.no_grad()` should be placed so that gradients are not computed for the tensors.
Should I place `with torch.no_grad()` before operations such as `model(imgs)` and `criterion(preds, labels)`, or
should I use it like this, e.g.:

for imgs, labels in dataloader:
    imgs = imgs.to(device)
    labels = labels.to(device)
    # Inference only: inside this context autograd records no graph,
    # so `preds` does not require grad.
    with torch.no_grad():
        model.eval()
        preds = model(imgs)
    # the rest
    # `preds` carries no graph, so this loss is not differentiable either.
    loss = criterion(preds, labels)

or

for imgs, labels in dataloader:
    # Same as the variant above, except the device transfers also sit inside
    # the no-grad context; moving tensors does not change the result here.
    with torch.no_grad():
        imgs = imgs.to(device)
        labels = labels.to(device)
        model.eval()
        preds = model(imgs)
    # the rest
    loss = criterion(preds, labels)
    # acc, etc

Both snippets behave the same if you only want to run inference and your inputs don’t require gradients.

1 Like

Thank you very much.
So you mean simply doing something like this:

def train_validation_loop(model, dataloader, optimizer, criterion, is_training,
                          device, topk=1, interval=1000):
    """Run one epoch of training or validation over ``dataloader``.

    A single loop serves both phases. The forward pass and loss run under
    ``torch.set_grad_enabled(is_training)``: with ``is_training=True`` the
    autograd graph is recorded and the optimizer is stepped; with
    ``is_training=False`` it behaves like ``torch.no_grad()`` and no graph is
    built.

    Args:
        model: network to evaluate; put in ``train()`` or ``eval()`` mode once
            at the start of the epoch.
        dataloader: sized iterable of ``(imgs, labels)`` batches.
        optimizer: stepped after each batch in training mode; not used (may be
            ``None``) in validation mode.
        criterion: loss function, called as ``criterion(preds, labels)``.
        is_training: ``True`` for a training epoch, ``False`` for validation.
        device: target passed to ``Tensor.to`` for every batch.
        topk: a sample counts as correct when its label is among the ``topk``
            highest-scoring classes.
        interval: print per-batch loss/accuracy every ``interval`` batches;
            ``0`` or negative disables per-batch printing.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches.

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    total_batches = len(dataloader)
    if total_batches == 0:
        raise ValueError('dataloader is empty')
    status = 'training' if is_training else 'validation'

    # Set the module mode once per epoch instead of on every batch.
    if is_training:
        model.train()
    else:
        model.eval()

    loss_total = 0.0
    accuracy_total = 0.0
    for i, (imgs, labels) in enumerate(dataloader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Build the autograd graph only when we are going to backpropagate.
        with torch.set_grad_enabled(is_training):
            preds = model(imgs)
            loss = criterion(preds, labels)

        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_loss = loss.item()
        batch_accuracy = _topk_accuracy(preds, labels, topk)
        loss_total += batch_loss
        accuracy_total += batch_accuracy

        if interval > 0 and i % interval == 0:
            print(f'{status} loss/accuracy(per batch): {batch_loss:.6f} / {batch_accuracy:.4f}')

    # calculate the loss/accuracy after each epoch
    mean_loss = loss_total / total_batches
    mean_accuracy = accuracy_total / total_batches
    print(f'{status} loss/accuracy: {mean_loss:.6f} / {mean_accuracy:.4f}')
    return mean_loss, mean_accuracy


def _topk_accuracy(preds, labels, topk):
    """Return the fraction of samples whose label is in the top-``topk`` scores.

    Assumes ``preds`` holds per-class scores along dim 1 and ``labels`` holds
    one class index per row of ``preds`` (TODO confirm for non-2D outputs).
    """
    with torch.no_grad():
        _, class_idxs = preds.topk(topk, dim=1)
        # Compare each row's top-k indices with its label; any match is a hit.
        hits = (class_idxs == labels.view(-1, 1)).any(dim=1)
        return hits.float().mean().item()

will this compute gradients in training mode and skip them in validation mode?
So basically