Autocast and torch.no_grad() unexpected behaviour

import torch
from torch.cuda.amp import autocast

# requires a CUDA device
net = torch.nn.Conv2d(3, 3, 3, 3).to('cuda')
input = torch.rand([3, 3, 5, 5], device='cuda')

with autocast():
    with torch.no_grad():
        y = net(input)  # grad disabled here, as intended

    z = net(input)  # outside no_grad, so grad should be enabled again
    print('z {}'.format(z.requires_grad))

This prints "z False": z.requires_grad is False even though z was computed outside the no_grad() block. Is this by design?


Thanks for reporting this issue. I can reproduce it, and it looks like a bug: autograd seems to keep gradients disabled even after leaving the no_grad() block:

with autocast(enabled=True):
    with torch.no_grad():
        y = net(input)

    z = net(input)
    print('z {}'.format(z.requires_grad))
> False

with autocast(enabled=False):
    with torch.no_grad():
        y = net(input)

    z = net(input)
    print('z {}'.format(z.requires_grad))
> True

@mcarilli Have you seen this behavior before?

I haven’t seen this behavior before, but I know why it’s happening. Autocast maintains a cache of the FP16 casts of model params (leaves). This streamlines parameter reuse: if the same FP32 param is used by several ops on autocast’s FP16 list (for example, several matmuls), then instead of re-casting the param to FP16 on entry to each matmul, the cast happens on the first matmul, the FP16 copy is cached, and all later matmuls reuse that cached copy.

The cache lives only as long as the outermost autocast context; when you exit that context, the cache is dropped. With the recommended usage, where autocast wraps the forward pass and you exit the context before calling backward(), the cache lasts only for the duration of each iteration’s forward pass and is rebuilt on the next iteration. (The cache of FP16 copies MUST be rebuilt every iteration: the optimizer updates the FP32 params, so FP16 copies carried over from a previous iteration would hold stale values.)
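A toy, pure-Python model of the cache lifecycle just described may help. It is only an illustration under simplifying assumptions, not PyTorch’s implementation: Param, ToyAutocast and matmul are hypothetical stand-ins, "casting" just snapshots the param’s current value, and nesting of autocast contexts is ignored.

```python
class Param:
    """Hypothetical stand-in for an FP32 leaf parameter."""
    def __init__(self, value):
        self.value = value


class ToyAutocast:
    """Toy context that caches one 'FP16 copy' per param while it is active."""
    def __init__(self):
        self.cache = {}
        self.casts = 0  # number of real casts performed (cache misses)

    def __enter__(self):
        self.cache = {}  # fresh cache on each entry into the context
        return self

    def __exit__(self, *exc):
        self.cache.clear()  # the cache is dropped when the context exits
        return False

    def cast(self, param):
        key = id(param)
        if key not in self.cache:
            self.casts += 1
            self.cache[key] = param.value  # snapshot of the param's current value
        return self.cache[key]


def matmul(ctx, param, x):
    """Toy 'FP16-list op': always goes through the context's cache."""
    return ctx.cast(param) * x


w = Param(2.0)
ac = ToyAutocast()
with ac:
    outs = [matmul(ac, w, x) for x in (1.0, 2.0, 3.0)]
    casts_in_first_context = ac.casts  # 1: cast once, reused by the later calls
    w.value = 5.0                      # pretend the optimizer updated w mid-context
    stale = matmul(ac, w, 1.0)         # 2.0: the cached copy is now stale
with ac:
    fresh = matmul(ac, w, 1.0)         # 5.0: new context, so the copy is rebuilt
```

The stale value in the middle shows why the cache must not outlive a single forward pass: the "optimizer step" changed w, but the cached copy did not.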

The behavior you observe happens because you run both a no_grad forward pass and a grad-enabled forward pass inside the same autocast context. During the no_grad pass, FP16 copies of the params are created and cached, and because grad mode is off at that point, they are created with requires_grad=False. When you run net(input) again with grad enabled, you are still inside the same autocast context, so the cache is still live and the FP16 copies are not recreated; net’s FP16-list ops reuse the cached copies instead. Since those cached copies have requires_grad=False, net(input) does not build an autograd graph, and z ends up with requires_grad=False.
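The same mechanism can be reproduced without a GPU in a toy, pure-Python model. This is a hypothetical sketch, not PyTorch code: Tensor, ToyAutocast, toy_no_grad and linear are invented stand-ins, and grad mode is modeled as a single global flag. The key assumption it encodes is the one described above: a cast copy records whatever grad mode was active when it was first created, and later uses of the cache inherit that.

```python
from contextlib import contextmanager

GRAD_ENABLED = True  # toy stand-in for autograd's global grad mode


@contextmanager
def toy_no_grad():
    """Toy version of torch.no_grad(): disables grad mode, then restores it."""
    global GRAD_ENABLED
    prev = GRAD_ENABLED
    GRAD_ENABLED = False
    try:
        yield
    finally:
        GRAD_ENABLED = prev


class Tensor:
    def __init__(self, value, requires_grad=False):
        self.value = value
        self.requires_grad = requires_grad


class ToyAutocast:
    """Caches one cast copy per tensor for the lifetime of the context."""
    def __init__(self):
        self.cache = {}

    def __enter__(self):
        self.cache = {}
        return self

    def __exit__(self, *exc):
        self.cache.clear()
        return False

    def cast(self, t):
        if id(t) not in self.cache:
            # Like any op, the cast only tracks grad if grad mode is on right now.
            self.cache[id(t)] = Tensor(t.value, t.requires_grad and GRAD_ENABLED)
        return self.cache[id(t)]


def linear(ac, weight, x):
    w16 = ac.cast(weight)
    # Output needs grad only if grad mode is on and some input needs grad.
    return Tensor(w16.value * x.value,
                  GRAD_ENABLED and (w16.requires_grad or x.requires_grad))


weight = Tensor(2.0, requires_grad=True)
x = Tensor(3.0)

with ToyAutocast() as ac:
    with toy_no_grad():
        y = linear(ac, weight, x)  # caches a requires_grad=False copy of weight
    z = linear(ac, weight, x)      # grad is back on, but the cached copy is reused
print(z.requires_grad)  # False

with ToyAutocast() as ac:         # a fresh context rebuilds the cache
    z2 = linear(ac, weight, x)
print(z2.requires_grad)  # True
```

The second context mirrors the "new autocast context" workaround: with an empty cache, the copy of weight is created under grad mode and z2 tracks gradients.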

You can restore expected behavior by exiting the autocast context before running the second forward pass:

with autocast(enabled=True):
    with torch.no_grad():
        y = net(input)
z = net(input)
print('z {}'.format(z.requires_grad)) # will print True

or by using a new autocast context for the second forward pass:

with autocast(enabled=True):
    with torch.no_grad():
        y = net(input)
with autocast(enabled=True):
    z = net(input)
    print('z {}'.format(z.requires_grad)) # will print True

I’m hesitant to call this a bug per se. The API does recommend using a separate invocation of autocast for every forward pass. I suppose you could run into trouble if one section of a given forward pass ran under no_grad, another section ran with grad enabled, and a particular FP32 param was used in both sections, but that seems outlandish. What do you guys think?


Thank you for your explanation and advice. However, I still find this behavior unexpected. A model may be called several times in one forward pass (possibly by functions the user didn’t write), and gradients may not be needed in every call, e.g. in semi-supervised learning. In complicated cases like these, I think it will be difficult to handle the cache correctly by using separate autocast contexts.

Would it be possible for the FP16 and FP32 tensors that belong to the “same” model variable to share requires_grad information? Or maybe the no_grad() context could be extended to take care of the cache automatically. I think the latter may be more practical, since there are many cases where the cache may cause problems.
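One way the second suggestion could look is sketched below as a toy, pure-Python model. It is hypothetical and not how PyTorch necessarily handles this: GradAwareAutocast, Tensor, toy_no_grad and linear are invented stand-ins, and the one change from a plain per-param cache is that the cache key includes the grad mode in effect, so a copy made under no_grad is never handed to a grad-enabled call (for example, a "teacher" pass under no_grad followed by a "student" pass with grad, inside one autocast context).

```python
from contextlib import contextmanager

_grad_enabled = True  # toy stand-in for autograd's global grad mode


@contextmanager
def toy_no_grad():
    """Toy version of torch.no_grad(): disables grad mode, then restores it."""
    global _grad_enabled
    prev, _grad_enabled = _grad_enabled, False
    try:
        yield
    finally:
        _grad_enabled = prev


class Tensor:
    def __init__(self, value, requires_grad=False):
        self.value = value
        self.requires_grad = requires_grad


class GradAwareAutocast:
    """Hypothetical cache keyed on (tensor, grad mode) instead of tensor alone."""
    def __init__(self):
        self.cache = {}
        self.casts = 0  # number of real casts performed

    def __enter__(self):
        self.cache = {}
        return self

    def __exit__(self, *exc):
        self.cache.clear()
        return False

    def cast(self, t):
        key = (id(t), _grad_enabled)
        if key not in self.cache:
            self.casts += 1
            self.cache[key] = Tensor(t.value, t.requires_grad and _grad_enabled)
        return self.cache[key]


def linear(ac, weight, x):
    w16 = ac.cast(weight)
    return Tensor(w16.value * x.value,
                  _grad_enabled and (w16.requires_grad or x.requires_grad))


weight = Tensor(2.0, requires_grad=True)
x = Tensor(3.0)
with GradAwareAutocast() as ac:
    with toy_no_grad():
        y = linear(ac, weight, x)    # "teacher" pass: no-grad copy is cached
    z = linear(ac, weight, x)        # "student" pass: gets its own grad-tracking copy
    z_again = linear(ac, weight, x)  # reuses the grad-tracking copy (no new cast)
```

In this toy, the trade-off is that a param used in both modes is cast twice per context (ac.casts ends at 2). An alternative design would be to skip caching entirely under no_grad, which stores no extra copy but re-casts on every no_grad use.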