I am testing whether it is allowed to update another module inside a backward hook, using the following code:
```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam

x = Variable(torch.randn(3, 4), requires_grad=True)
x2 = Variable(torch.randn(3, 4), requires_grad=True)
y = nn.Linear(4, 5)

a = (3 * x).sum()
b = y(x2)

oa = Adam([x])
ob = Adam(y.parameters())

print('Before update', y)

# Backward hook on a: runs a second backward pass and an optimizer
# step on a graph (b) that does not involve a or x at all.
def update(grad):
    loss = b.sum()
    loss.backward()
    ob.step()
    print('After update', y)

a.register_hook(update)
a.backward(retain_variables=True)  # retain_variables is the pre-0.2 name of retain_graph
```
As you can see, what the `update` function does is irrelevant to the Variable `a` and to `x`, but I found that the process hangs when calling `loss.backward()`. Is it allowed to call `backward()` inside a backward hook?
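In case it helps, here is a minimal sketch of the pattern I would expect to be safe, assuming the hook only needs to *trigger* the update rather than perform it: the hook just records that `a`'s gradient arrived, and the second backward pass plus optimizer step run after the outer `backward()` has returned, so the two passes never nest. The `pending` list is a helper I introduce only for this sketch; it assumes the graph-building part of the snippet above has run, but not the original hook registration or `backward()` call.

```python
# Workaround sketch (an assumption, not verified): defer the autograd
# work so that no backward() call happens while another one is running.
# Reuses x, x2, y, a, b, and ob from the snippet above.
pending = []

def update(grad):
    # Only record that a's gradient is ready; do no autograd work here.
    pending.append(True)

a.register_hook(update)
a.backward(retain_variables=True)

# The outer backward() has finished, so running a second, independent
# backward pass and stepping the optimizer should be safe now.
if pending:
    loss = b.sum()
    loss.backward()
    ob.step()
    print('After update', y)
```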