Hi all,

I’m trying to backprop a model using the gradients acquired during the forward and backward pass of a copy of the exact same model. I capture the gradient with a hook:

```
# define the models
class ANet(nn.Module): ...

class BNet(nn.Module):
    def __init__(self):
        super().__init__()
        ... layers ...
        self.grad = None

    def forward(self, x):
        # assignment isn't allowed inside a lambda, so use setattr
        # to store dLoss/dx (the gradient w.r.t. BNet's input)
        x.register_hook(lambda g: setattr(self, "grad", g))
        ... layers ...
        return x
```
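As a sanity check, here is a standalone toy version (not the actual models) of how I understand the hook mechanism: a hook registered on an intermediate tensor receives dLoss/d(tensor) as a plain tensor.

```python
import torch

# Toy check: capture the gradient flowing into an intermediate tensor.
captured = {}

x = torch.randn(3, requires_grad=True)
y = x * 2.0                                         # intermediate tensor
y.register_hook(lambda g: captured.update(grad=g))  # stores dLoss/dy
loss = y.sum()
loss.backward()

print(captured["grad"])  # dLoss/dy = ones(3) here
```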

Running this model yields the proper gradient:

```
# instantiate models
a1net = ANet()
a2net = deepcopy(a1net)
bnet = BNet()
# forward and back prop a1 and b
X, y = next(iter(train_loader))
yhat = bnet(a1net(X))
loss = F.nll_loss(yhat, y)
loss.backward()
```

Afterwards, I’d like to update `a2net` using the extracted gradient. I usually just call `.backward()` to kick off backpropagation through autograd, but here I need to apply the gradient captured in the previous pass.

```
# backprop a2, this is where I'm stuck
grad = bnet.grad
a2net.backward(grad)...?
```

I’ve thought of using `torch.autograd.backward(tensors=a1net, grad_tensors=grad)`, but that didn’t work.
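For reference, here is a standalone sketch of the kind of call I think is needed (a toy `nn.Linear` stands in for `a2net`; I’m assuming `Tensor.backward(gradient=...)` is the way to seed backprop for a non-scalar output):

```python
import torch

net = torch.nn.Linear(4, 2)    # stand-in for a2net
X = torch.randn(5, 4)
out = net(X)                   # non-scalar output, shape (5, 2)
grad = torch.randn_like(out)   # stand-in for the extracted bnet.grad

# Seed backprop with the externally supplied gradient; the equivalent
# functional form would be:
# torch.autograd.backward(tensors=[out], grad_tensors=[grad])
out.backward(gradient=grad)

print(net.weight.grad.shape)   # torch.Size([2, 4])
```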

Comparing my extracted gradient with the output of `a1net(X)` shows that my grad lacks a reference to the graph in `grad.grad_fn`, and I also wouldn’t know how to explicitly create a grad for a non-scalar output.
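To illustrate the `grad_fn` mismatch I’m seeing, in a toy setup the activation has a `grad_fn` but the hooked gradient doesn’t (it comes out as a plain tensor detached from any graph):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2.0
saved = {}
y.register_hook(lambda g: saved.update(g=g))
y.sum().backward()

print(y.grad_fn is None)           # False: the activation is part of the graph
print(saved["g"].grad_fn is None)  # True: the hooked gradient carries no graph
```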

Most related questions I’ve found here are about retaining graphs, but always within a single model.

Best,