I haven’t seen this behavior before, but I know why it’s happening. Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16-list ops, like several matmuls, then instead of re-casting the param to FP16 on entering each matmul, the cast occurs on the first matmul, the casted FP16 copy is cached, and all later matmuls reuse that FP16 copy.

The cache is maintained only within a particular outermost `autocast` context; when you exit the `autocast` context, the cache is dropped. With the recommended usage, in which `autocast` wraps the forward pass and you exit the context before calling `backward()`, this means the cache only lasts for the duration of each iteration's forward pass and is rebuilt the next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration: the optimizer updates the FP32 params, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.)
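For reference, here's a minimal sketch of that recommended per-iteration pattern. The toy `net`, `opt`, `loss_fn`, and `data` are stand-ins (not from your script), and GradScaler is omitted to keep the focus on the cache lifetime:

    import torch
    from torch.cuda.amp import autocast

    # Toy setup so the snippet runs on its own; your real model/optimizer will differ.
    net = torch.nn.Linear(16, 16).cuda()
    opt = torch.optim.SGD(net.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    data = [(torch.randn(8, 16, device='cuda'),
             torch.randn(8, 16, device='cuda')) for _ in range(3)]

    for input, target in data:
        opt.zero_grad()
        with autocast(enabled=True):
            # FP16 param copies are created and cached here,
            # and live only for this forward pass.
            output = net(input)
            loss = loss_fn(output, target)
        # Exiting autocast drops the cache; backward() runs outside it,
        # and next iteration re-casts the freshly updated FP32 params.
        loss.backward()
        opt.step()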
The behavior you observe happens because you do both a `no_grad` forward pass and a grad-enabled forward pass within the same `autocast` context. In the `no_grad` forward pass, FP16 param copies are created and cached. Because it's a `no_grad` context, these FP16 copies are created with `requires_grad=False`. When you run `net(input)` again in a grad-exposed way, you are still within the same `autocast` context, so the cache is live and the FP16 copies are not recreated (instead, `net`'s FP16-list ops use the cached copies). Since these cached copies have `requires_grad=False`, `net(input)` does not build an autograd graph, and `z` ends up with `requires_grad=False`.
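A minimal repro sketch of that sequence, assuming a toy `net` and `input` (the exact model in your script will differ):

    import torch
    from torch.cuda.amp import autocast

    net = torch.nn.Linear(4, 4).cuda()
    input = torch.randn(2, 4, device='cuda')

    with autocast(enabled=True):
        with torch.no_grad():
            y = net(input)   # FP16 param copies are cached with requires_grad=False
        z = net(input)       # same autocast context, so the cached copies are reused

    print('z {}'.format(z.requires_grad))  # prints False on versions with this caching behavior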
You can restore the expected behavior by exiting the `autocast` context before running the second forward pass:

    with autocast(enabled=True):
        with torch.no_grad():
            y = net(input)

    z = net(input)
    print('z {}'.format(z.requires_grad))  # will print True
or by using a new `autocast` context for the second forward pass:

    with autocast(enabled=True):
        with torch.no_grad():
            y = net(input)

    with autocast(enabled=True):
        z = net(input)

    print('z {}'.format(z.requires_grad))  # will print True
I’m hesitant to call this a bug per se. The API docs do recommend using a separate invocation of `autocast` for every forward pass. I suppose you could run into trouble if one section of a given forward pass ran under `no_grad`, another section was autograd-exposed, and a particular FP32 param was used in both sections, but that seems outlandish (sketched below). What do you guys think?
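For concreteness, a hedged sketch of that edge case. The `SplitForwardNet` module below is made up for illustration, and the comments describe what the caching above implies rather than a confirmed repro:

    import torch
    from torch.cuda.amp import autocast

    class SplitForwardNet(torch.nn.Module):
        """One forward pass that uses the same FP32 param under no_grad
        and then again in an autograd-exposed section."""
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            with torch.no_grad():
                # First matmul caches an FP16 copy of self.weight
                # with requires_grad=False.
                stats = (x @ self.weight).mean()
            # Second matmul, still in the same autocast context, would reuse
            # that grad-less cached copy, detaching the output from self.weight.
            return x @ self.weight + stats

    net = SplitForwardNet().cuda()
    x = torch.randn(2, 4, device='cuda')
    with autocast(enabled=True):
        out = net(x)
    print('out {}'.format(out.requires_grad))  # could print False, per the caching above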