Thanks for reporting this issue. I can reproduce it, and it does look like a bug: autograd seems to keep gradient tracking disabled even after leaving the `no_grad()` block, but only while `autocast` is enabled:
```python
with autocast(enabled=True):
    with torch.no_grad():
        y = net(input)
    z = net(input)  # computed after leaving no_grad, still inside autocast
print('z {}'.format(z.requires_grad))
# > z False
```

```python
with autocast(enabled=False):
    with torch.no_grad():
        y = net(input)
    z = net(input)
print('z {}'.format(z.requires_grad))
# > z True
```
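For comparison, the expected scoping behavior of `torch.no_grad()` with autocast out of the picture can be checked with a minimal CPU-only sketch (`net` here is a placeholder `Linear` module, since the original `net` and `input` are not shown in the snippet above):

```python
import torch

net = torch.nn.Linear(4, 4)   # placeholder stand-in for the original net
inp = torch.randn(2, 4)

with torch.no_grad():
    y = net(inp)              # produced with grad mode disabled
z = net(inp)                  # grad mode is restored after leaving the block

print('y {}'.format(y.requires_grad))  # y False
print('z {}'.format(z.requires_grad))  # z True
```

Here `z.requires_grad` is `True` as expected, which is why the `False` result under `autocast(enabled=True)` above looks like autocast-specific state leaking out of the `no_grad()` scope.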
@mcarilli Have you seen this behavior before?