Typically I use the with torch.no_grad() context in an evaluate function, like the following example:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def evaluate(model: nn.Module, val_loader):
        model.eval()
        val_loss = 0.0
        correct = 0
        criterion = nn.BCEWithLogitsLoss(reduction="sum")  # sum, so dividing by the dataset size below gives the mean
        with torch.no_grad():
            for images_1, images_2, targets in val_loader:
                images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device).float()
                with torch.autocast(device_type=device.type):
                    outputs = model(images_1, images_2).squeeze(1)  # logits; assumes model output shape (N, 1)
                    val_loss += criterion(outputs, targets).item()  # sum up batch loss
                pred = (outputs > 0).long()  # logit > 0 is equivalent to sigmoid(logit) > 0.5
                correct += pred.eq(targets.view_as(pred)).sum().item()
        val_loss /= len(val_loader.dataset)
        print(f'Valid set: Loss: {val_loss:.4f}, Accuracy: {100.0 * correct / len(val_loader.dataset):.1f}%')

Now I know we can put the @torch.no_grad() decorator above a function, and it acts like wrapping the whole function body in a with torch.no_grad() context (ref: pytorch - What is the difference between "@torch.no_grad()" and "with torch.no_grad()" - Stack Overflow).
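A small sketch of that equivalence, using two hypothetical functions, f_ctx and f_dec, that compute the same thing: one wraps its body in the context manager, the other uses the decorator. In both cases the result does not require grad, and grad mode is back on once the function returns.

```python
import torch

w = torch.randn(3, requires_grad=True)  # a parameter that would normally be tracked

def f_ctx(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():  # grad tracking disabled only inside this block
        return (w * x).sum()

@torch.no_grad()  # grad tracking disabled for the entire function body
def f_dec(x: torch.Tensor) -> torch.Tensor:
    return (w * x).sum()

x = torch.ones(3)
a, b = f_ctx(x), f_dec(x)
```

Both a and b are plain tensors with no autograd history, while the same expression evaluated outside either function still tracks gradients.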
So, for the typical example above, will there be any difference if I replace the with torch.no_grad() context with the @torch.no_grad() decorator? That is, the whole function body would then run in no_grad mode, including the model.eval() line and everything else.
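The two switches are independent: model.eval() only sets the training flag on the module and all its submodules (which changes the behaviour of layers such as Dropout and BatchNorm), while no_grad only disables autograd recording. A quick check with a toy model (the probe function and the Sequential model below are hypothetical) suggests that calling model.eval() inside a no_grad-decorated function still puts the model in eval mode, and that grad mode is restored after the call.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def probe(model: nn.Module) -> dict:
    # Everything here, including model.eval(), runs with grad mode disabled.
    model.eval()  # only flips the `training` flag on every submodule
    x = torch.randn(3, 4, requires_grad=True)
    y = model(x)
    return {
        "grad_enabled_inside": torch.is_grad_enabled(),
        "output_requires_grad": y.requires_grad,
        "training_flag": model.training,
        "dropout_training": model[1].training,
    }

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
state = probe(model)
```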
My motivation for using the @torch.no_grad() decorator is readability (one less level of indentation in the function body).
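A minimal sketch of what the decorated version of evaluate could look like, assuming the same model signature (two image batches in, one logit per pair out, shape (N, 1)) and a loader yielding (images_1, images_2, targets). For testability this sketch passes device as a parameter and returns the loss and accuracy instead of printing them; autocast is left out for brevity and could be kept as an inner with block.

```python
import torch
import torch.nn as nn

@torch.no_grad()  # the whole body runs with gradient tracking disabled
def evaluate(model: nn.Module, val_loader, device: torch.device):
    model.eval()  # still needed: grad mode and train/eval mode are separate switches
    criterion = nn.BCEWithLogitsLoss(reduction="sum")  # sum over each batch
    val_loss, correct = 0.0, 0
    for images_1, images_2, targets in val_loader:
        images_1, images_2 = images_1.to(device), images_2.to(device)
        targets = targets.to(device).float()
        outputs = model(images_1, images_2).squeeze(1)  # logits, shape (N,)
        val_loss += criterion(outputs, targets).item()
        pred = (outputs > 0).long()  # logit > 0 is equivalent to sigmoid(logit) > 0.5
        correct += pred.eq(targets.long()).sum().item()
    n = len(val_loader.dataset)
    return val_loss / n, correct / n
```

The body is now one indentation level shallower, and the loop behaves as it would inside the context manager.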