Torch.cuda.amp blocks gradient computation

Due to the internal caching in autocast, you would have to exit the autocast context before running the second forward pass, as described here.
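
A minimal sketch of what that could look like; the model, data, and losses are made up for illustration, and the exact structure of your two forward passes will differ:

```python
import torch

# hypothetical setup, for illustration only
model = torch.nn.Linear(16, 16).cuda()
data = torch.randn(8, 16, device="cuda")

with torch.cuda.amp.autocast():
    # first forward pass runs under autocast and populates
    # autocast's internal cache of casted tensors
    out1 = model(data)
    loss1 = out1.float().mean()

# exit the autocast region before the second forward pass so the
# cached casts from the first pass are not reused; the second pass
# here runs in a fresh autocast region after leaving the first one
with torch.cuda.amp.autocast():
    out2 = model(data)
    loss2 = out2.float().mean()

(loss1 + loss2).backward()
```

Alternatively, you could disable the cache for the affected region via `torch.cuda.amp.autocast(cache_enabled=False)` if restructuring the forward passes isn't feasible.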

CC @vadimkantorov, since we've discussed more general use cases and this one might be important for you.