Manually start autograd on a model using a given gradient

Hi all,

I’m trying to backprop a model using the gradient obtained from the forward and backward pass of an exact copy of that model. I capture the gradient with a hook:

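```python
import torch.nn as nn

# define the models
class ANet(nn.Module): ...

class BNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ... layers ...
        self.grad = None

    def _save_grad(self, g):
        # a lambda cannot contain an assignment, so store the gradient in a method
        self.grad = g

    def forward(self, x):
        # the hook fires during backward and receives dLoss/dx
        x.register_hook(self._save_grad)
        # ... layers ...
        return x
```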
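A minimal, self-contained sketch of the hook mechanism, assuming toy layers stand in for ANet and BNet (the `GradCatcher` class and the layer sizes are hypothetical). It shows that the hook only fires once `backward()` runs, and that the stored gradient has the same shape as ANet's output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class GradCatcher(nn.Module):
    """Hypothetical stand-in for BNet: stores the gradient w.r.t. its input."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)  # assumed layer size
        self.grad = None

    def _save_grad(self, g):
        self.grad = g  # returning None leaves the gradient unchanged

    def forward(self, x):
        x.register_hook(self._save_grad)  # runs during backward, not forward
        return F.log_softmax(self.fc(x), dim=1)

anet = nn.Linear(5, 4)  # assumed stand-in for ANet
bnet = GradCatcher()

X = torch.randn(2, 5)
y = torch.tensor([0, 2])
h = anet(X)
loss = F.nll_loss(bnet(h), y)
print(bnet.grad)        # None: the hook has not fired yet
loss.backward()
print(bnet.grad.shape)  # torch.Size([2, 4]), same shape as anet's output
```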

Running the models as follows yields the expected gradient in bnet.grad:

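```python
from copy import deepcopy

import torch.nn.functional as F

# instantiate models
a1net = ANet()
a2net = deepcopy(a1net)
bnet = BNet()

# forward a1net and bnet, then backprop; the hook fills bnet.grad
X, y = next(iter(train_loader))
yhat = bnet(a1net(X))
loss = F.nll_loss(yhat, y)
loss.backward()
```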

Afterwards, I’d like to update a2net using the extracted gradient. Normally I just call .backward() on a loss to start backpropagation with autograd, but here I need to start it from the gradient captured in the previous step.

```python
# backprop a2net: this is where I'm stuck
grad = bnet.grad
```

I’ve thought of using torch.autograd.backward(tensors=a1net, grad_tensors=grad), but that didn’t work.

Comparing my extracted gradient with the result of a1net(X) shows that my grad lacks the reference to the model in grad._grad_fn. I also don’t know how to explicitly supply a gradient for a non-scalar output.
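For the non-scalar case, Tensor.backward() accepts a `gradient` argument (torch.autograd.backward() calls it `grad_tensors`). It must be a tensor with the same shape as the output tensor being backpropagated, not the module itself, and autograd then computes the vector-Jacobian product. A small sketch with a toy elementwise function; the values are illustrative only:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = x ** 2  # non-scalar output, d(out_i)/d(x_i) = 2 * x_i

# upstream gradient, e.g. one captured by a hook elsewhere
upstream = torch.tensor([1.0, 0.5, 0.0])

# equivalent to torch.autograd.backward(out, grad_tensors=upstream)
out.backward(gradient=upstream)
print(x.grad)  # tensor([2., 2., 0.])
```

Each entry of x.grad is 2 * x_i * upstream_i, i.e. the upstream gradient chained through the local derivative.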

Most related questions I’ve found here are about retaining graphs, and they always involve a single model.


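You could most likely copy the parameter gradients from a1net to a2net and then call the optimizer’s step() to update a2net. Since a2net is an exact copy of a1net, backpropagating the same input through it would produce the same parameter gradients anyway.
Here is a small example of this workflow: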

```python
import copy

import torch
import torch.nn as nn

model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.Linear(10, 1),
)
model2 = copy.deepcopy(model1)

# Compare parameters: both models should start out identical
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    if (param1 != param2).any():
        raise RuntimeError('Param mismatch')

# create optimizer for model2
optimizer = torch.optim.SGD(model2.parameters(), lr=1e-3)

# calculate gradients in model1 (single-element output, so no gradient argument is needed)
model1(torch.randn(1, 10)).backward()
for name, param in model1.named_parameters():
    print(name, param.grad.norm())

# Copy gradients to model2 (clone, so model2 does not share model1's grad tensors)
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    param2.grad = param1.grad.clone()

# Update model2
optimizer.step()

# Compare parameters again: model2 should now differ from model1
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    print((param1 - param2).abs().sum())
```
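An alternative sketch that uses the hook-captured gradient directly instead of copying parameter gradients: forward X through a2net and pass the captured gradient as the upstream gradient of its output. The toy nn.Sequential stand-ins for ANet and BNet and the `captured` dict are assumptions for illustration. Because a2net is a deep copy of a1net and sees the same input, its parameter gradients should match a1net’s:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical stand-ins for ANet and BNet
a1net = nn.Sequential(nn.Linear(5, 4), nn.ReLU())
a2net = copy.deepcopy(a1net)
bnet = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))

X = torch.randn(8, 5)
y = torch.randint(0, 3, (8,))

# forward a1net -> bnet and capture dLoss/d(a1net output) with a hook;
# dict.update works inside a lambda, where plain assignment does not
captured = {}
h1 = a1net(X)
h1.register_hook(lambda g: captured.update(grad=g))
loss = F.nll_loss(bnet(h1), y)
loss.backward()

# backprop a2net starting from the captured gradient
optimizer = torch.optim.SGD(a2net.parameters(), lr=1e-3)
optimizer.zero_grad()
h2 = a2net(X)
h2.backward(gradient=captured["grad"])

# a2net's gradients match a1net's, since both are identical copies
for p1, p2 in zip(a1net.parameters(), a2net.parameters()):
    print(torch.allclose(p1.grad, p2.grad))

optimizer.step()
```

This avoids the grad-copying loop, at the cost of one extra forward pass through a2net.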

Let me know if that works.