Torch.cuda.amp blocks gradient computation

Hi!

I ran into an issue while using torch.cuda.amp.autocast(). Inside the autocast context I ran a forward pass under torch.no_grad(), and that blocked gradient computation for the rest of the autocast context.

Example to reproduce:

```python
import torch
import torch.nn as nn
net = nn.Linear(100, 10).cuda()
input = torch.randn(32, 100, device='cuda')
target = torch.randn(32, 10, device='cuda')
crit = nn.MSELoss()

with torch.cuda.amp.autocast():
    with torch.no_grad():
        out = net(input)

    new_out = net(input)
    loss = crit(new_out, target)
```

Then loss has no grad_fn (loss.grad_fn is None).

But if autocast is disabled:

```python
import torch
import torch.nn as nn
net = nn.Linear(100, 10).cuda()
input = torch.randn(32, 100, device='cuda')
target = torch.randn(32, 10, device='cuda')
crit = nn.MSELoss()

with torch.cuda.amp.autocast(False):
    with torch.no_grad():
        out = net(input)

    new_out = net(input)
    loss = crit(new_out, target)
```

Then loss has a grad_fn.

Is this expected behavior?

Thanks for any help!

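You would have to exit the autocast context before running the second forward pass, as described here, because of autocast’s internal caching: the lower-precision copies of the parameters are cached inside the autocast region, and a copy created under torch.no_grad() has no autograd history but is still reused by the later forward pass.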
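A minimal sketch of that workaround, written against the device-generic torch.autocast API rather than torch.cuda.amp.autocast. Assumptions not in the thread: the CPU/bfloat16 fallback (so it runs without a GPU), renaming input to inp, and computing the loss outside autocast on a float32 copy of the output. Option 2 assumes a PyTorch release whose autocast accepts the cache_enabled argument; it keeps one region but turns the cast cache off.

```python
import torch
import torch.nn as nn

# Assumption: CUDA/float16 when a GPU is available, otherwise CPU/bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

net = nn.Linear(100, 10).to(device)
inp = torch.randn(32, 100, device=device)
target = torch.randn(32, 10, device=device)
crit = nn.MSELoss()

# Option 1: give the no_grad pass its own autocast region. Leaving the
# outermost autocast region clears the cast cache, so the second region
# re-casts the weights and records autograd history.
with torch.autocast(device_type=device, dtype=amp_dtype):
    with torch.no_grad():
        out = net(inp)

with torch.autocast(device_type=device, dtype=amp_dtype):
    new_out = net(inp)
loss = crit(new_out.float(), target)

# Option 2 (assumes cache_enabled is supported): keep a single region
# but disable the weight-cast cache for it.
with torch.autocast(device_type=device, dtype=amp_dtype, cache_enabled=False):
    with torch.no_grad():
        out2 = net(inp)
    new_out2 = net(inp)
loss2 = crit(new_out2.float(), target)

print(loss.grad_fn is not None, loss2.grad_fn is not None)  # prints: True True
```

Option 1 keeps the cache (and its speed benefit) for the training pass; option 2 trades that cache away for not having to restructure the code.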

CC @vadimkantorov as we’ve discussed more general use cases and this one might be important for you.

Hi!

Thanks for the quick reply.
Sorry, I didn’t notice it was a duplicate.

For this particular case, how bad is it to call net.requires_grad_(False) to avoid recording the graph (and storing activations for backward), and then reset it to True afterwards?
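A sketch of the approach asked about, assuming the no_grad pass only needs to skip graph recording. frozen_params is a hypothetical helper, not a PyTorch API; it restores each parameter’s original flag even if the forward pass raises. My understanding (not confirmed in this thread) is that autocast only caches casts of leaf tensors that require grad, so casts made while the parameters are frozen should not be reused by the later pass. The CPU/bfloat16 fallback and the float32 loss computation are assumptions for portability.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def frozen_params(module: nn.Module):
    """Hypothetical helper: temporarily set requires_grad=False on every
    parameter of `module`, restoring the original flags on exit."""
    saved = [(p, p.requires_grad) for p in module.parameters()]
    try:
        module.requires_grad_(False)
        yield module
    finally:
        for p, flag in saved:
            p.requires_grad_(flag)


# Assumption: CUDA/float16 when a GPU is available, otherwise CPU/bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

net = nn.Linear(100, 10).to(device)
inp = torch.randn(32, 100, device=device)
target = torch.randn(32, 10, device=device)
crit = nn.MSELoss()

with torch.autocast(device_type=device, dtype=amp_dtype):
    with frozen_params(net):
        # Neither the input nor the parameters require grad here,
        # so no graph is recorded and no activations are saved.
        out = net(inp)
    # Parameters require grad again, so this pass is tracked.
    new_out = net(inp)

loss = crit(new_out.float(), target)
```

Note that requires_grad_(False) only changes the parameters; if an input to the frozen pass itself requires grad, a graph will still be recorded for it.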

Thanks

@ptrblck I guess this use case is another reason to support autocast-handling module wrappers :slight_smile:
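The thread does not spell out what such a wrapper would look like; one possible reading is a module that enters its own autocast region on every forward call. AutocastWrapper below is a hypothetical class, not part of PyTorch. When it is not nested inside an outer autocast region, each call exits the outermost region and clears the cast cache, so a no_grad call followed by a tracked call behaves as expected. The CPU/bfloat16 fallback is an assumption for portability.

```python
import torch
import torch.nn as nn


class AutocastWrapper(nn.Module):
    """Hypothetical wrapper: runs the wrapped module's forward inside
    its own autocast region."""

    def __init__(self, module: nn.Module, device_type: str, dtype: torch.dtype):
        super().__init__()
        self.module = module
        self.device_type = device_type
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        with torch.autocast(device_type=self.device_type, dtype=self.dtype):
            return self.module(*args, **kwargs)


# Assumption: CUDA/float16 when a GPU is available, otherwise CPU/bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

wrapped = AutocastWrapper(nn.Linear(100, 10).to(device), device, amp_dtype)
inp = torch.randn(32, 100, device=device)
target = torch.randn(32, 10, device=device)
crit = nn.MSELoss()

with torch.no_grad():
    out = wrapped(inp)   # untracked pass; its region closes on return
new_out = wrapped(inp)   # fresh region, weights re-cast with autograd history
loss = crit(new_out.float(), target)
```

If the wrapper is called inside an outer autocast block, the outer region stays open and its cache is not cleared between calls, so this sketch only helps when autocast is scoped per call.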