You could most likely copy the gradients to a2net and then apply the optimizer to update the model.
Here is a small example of this workflow:
import copy

import torch
import torch.nn as nn

model1 = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
)
model2 = copy.deepcopy(model1)

# Compare parameters to make sure both models start from the same state
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    if (param1 != param2).any():
        raise RuntimeError('Param mismatch')

# Create optimizer for model2
optimizer = torch.optim.SGD(model2.parameters(), lr=1e-3)

# Calculate gradients in model1 (the output has a single element, so backward() needs no arguments)
model1(torch.randn(1, 10)).backward()
for name, param in model1.named_parameters():
    print(name, param.grad.norm())

# Copy gradients to model2
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    param2.grad = param1.grad

# Update model2
optimizer.step()

# Compare parameters again; the differences should now be non-zero
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    print((param1 - param2).abs().sum())
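Note that param2.grad = param1.grad assigns the same gradient tensor to both models, so a later backward() call on model1 (or an in-place zero_grad() on either side) would also change the gradients seen by model2. If you want the two models to keep independent gradients, you could clone them instead (a minimal sketch, assuming the same model1/model2 setup as above):

# Copy independent gradient tensors instead of sharing the same storage
for param1, param2 in zip(model1.parameters(), model2.parameters()):
    param2.grad = param1.grad.clone()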
Let me know if that would work for you.