A way to detect whether code is executed within "with torch.no_grad():"

Hello,

I wanted to ask whether there is a way to check if code is being executed inside a with torch.no_grad(): block.
I have the following problem: I am evaluating my model within with torch.no_grad():. However, my model executes loss.backward() in its forward pass, since I want to adapt my deployed model on the unlabelled test data. After adapting my model, I want to evaluate it on the same data using my test function, but I am getting the following error: “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”.

I assume this error is thrown because my model is calling loss.backward() within with torch.no_grad():.
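The error can be reproduced in isolation (a standalone snippet, not my actual model):

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = (x * 2).sum()  # computed under no_grad: y.requires_grad is False, no grad_fn
    y.backward()       # RuntimeError: element 0 of tensors does not require grad ...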

Therefore, I need something like this:

# Basic structure of my test function
def test(model, dataloader):
    with torch.no_grad():
        for data, label in dataloader:
            out = model(data)  # this calls my model's forward function
            ...

# Custom forward function of my model
def forward_mymodel(self, input):
    out = self.net(input)  # self.net: placeholder for the wrapped base model
    if not in_torch_no_grad():  # TODO: how do I check this?
        loss = softmax_entropy(out).mean(0)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return out

I guess I could just check one of my tensors and see if it requires grad, but I was wondering if there is a more elegant solution for this?
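To illustrate what I mean (a standalone sketch, not my actual code; a raw dataloader tensor usually has requires_grad=False anyway, so I would have to test a tensor the model computes):

import torch

net = torch.nn.Linear(4, 2)   # stand-in for my actual model
x = torch.randn(1, 4)         # dataloader inputs usually have requires_grad=False

out = net(x)
print(out.requires_grad)      # True: the weights require grad, so the output does too

with torch.no_grad():
    out = net(x)
    print(out.requires_grad)  # False: no graph is recorded under no_grad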

I would appreciate it if someone could help me.

Thank you! :slight_smile:

torch.is_grad_enabled() should work:

import torch

print(torch.is_grad_enabled())
# True

with torch.no_grad():
    print(torch.is_grad_enabled())
    # False

with torch.inference_mode():
    print(torch.is_grad_enabled())
    # False
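In your case, the adaptation step in the forward function could then be guarded like this (a minimal sketch; self.net is a placeholder for your wrapped base model, and optimizer and softmax_entropy are your optimizer and loss from the post above):

import torch

def forward_mymodel(self, input):
    out = self.net(input)        # self.net: placeholder for the wrapped base model
    if torch.is_grad_enabled():  # False under torch.no_grad() and torch.inference_mode()
        loss = softmax_entropy(out).mean(0)  # adaptation only runs when grads are enabled
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return out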

Thank you so much for your fast response! :partying_face:

It works! :slight_smile: