Hi,
I was wondering whether it makes a difference to wrap the forward and backward functions of a custom autograd class in torch.no_grad().
Thanks!
No, it shouldn’t make a difference, since gradient computation is already disabled inside the forward and backward methods of a custom autograd.Function by default:
import torch

class MyFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        print(torch.is_grad_enabled())  # grad mode is already disabled here
        ctx.save_for_backward(input)    # save input to demonstrate ctx usage
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        print(torch.is_grad_enabled())  # grad mode is disabled here as well
        input, = ctx.saved_tensors
        return grad_output * 2          # d(input * 2) / d(input) = 2

f = MyFun.apply
x = torch.randn(1, 1, requires_grad=True)
out = f(x)
# prints: False
out.mean().backward()
# prints: False
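If you add an explicit torch.no_grad() anyway, the behavior shouldn’t change, since grad mode is already off inside these methods. Here is a minimal sketch of that redundant variant (the class name MyFunNoGrad is just for illustration, not from the original post):

import torch

class MyFunNoGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Redundant: grad mode is already disabled inside forward
        with torch.no_grad():
            print(torch.is_grad_enabled())  # still False
            ctx.save_for_backward(input)
            return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        print(torch.is_grad_enabled())      # False, no explicit no_grad needed
        return grad_output * 2

x = torch.randn(1, 1, requires_grad=True)
out = MyFunNoGrad.apply(x)  # prints: False, same as without no_grad()
out.mean().backward()       # prints: False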