Hi!
I ran into an issue while using torch.cuda.amp.autocast(): inside the context manager I called torch.no_grad(), and it blocked gradient computation for the rest of the autocast block.
Example to reproduce:
import torch
import torch.nn as nn

net = nn.Linear(100, 10).cuda()
input = torch.randn(32, 100, device='cuda')
target = torch.randn(32, 10, device='cuda')
crit = nn.MSELoss()

with torch.cuda.amp.autocast():
    with torch.no_grad():
        out = net(input)       # inference-only forward pass
    new_out = net(input)       # forward pass that should track gradients
    loss = crit(new_out, target)
Then loss will not have a grad_fn.
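A minimal check of the observation (assuming the snippet above has just been run; the comments reflect what I observe):
print(loss.requires_grad)  # False in my runs
print(loss.grad_fn)        # None in my runs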
But if we do not use autocast:
import torch
import torch.nn as nn

net = nn.Linear(100, 10).cuda()
input = torch.randn(32, 100, device='cuda')
target = torch.randn(32, 10, device='cuda')
crit = nn.MSELoss()

with torch.cuda.amp.autocast(False):
    with torch.no_grad():
        out = net(input)       # inference-only forward pass
    new_out = net(input)       # forward pass that should track gradients
    loss = crit(new_out, target)
Then loss has a grad_fn.
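Running the same check right after this second snippet gives the opposite result:
print(loss.requires_grad)  # True in my runs
print(loss.grad_fn)        # a valid backward node (the MSE loss backward function) in my runs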
Is this expected behavior?
Thanks for any help!