Manually modify gradients of two models, average them and put them back in both models!

I am trying to train two models on two mutually exclusive portions of a dataset.
While training both models, I want to manually extract the gradients from Model A and Model B after the backward pass. Then, before updating the weights, I want to average the two models' gradients, put that average back into both models, and update the weights.

I am trying it in the following way:

def mergeGrad(modelA, modelB, i):
    listGradA = [] 
    listGradB = [] 
    itr = 0
    for pA, pB in zip(modelA.parameters(), modelB.parameters()):
        avg = (pA.grad + pB.grad)/2
        pA.grad = avg
        pB.grad = avg
        itr += 1

In the optimization portion, I am doing:

        print(gradientA, gradientB)
        mergeGrad(modelA, modelB, i)
        print(gradientA, gradientB)

I am printing the gradients before and after averaging them and putting them back. However, only in the very first iteration does the first print statement output separate gradients for modelA and modelB; right after merging, I get the averaged gradients, as I want.

From the second iteration onward, though, both models always show the same gradients after the forward and backward pass (and hence the same averaged values). How is this possible?
The models are seeing different data!

Is the manual gradient modification that I am doing wrong?
Please do suggest if there’s a better way to manually modify the gradients and put them back into the models!
Thanks a lot!


I would bet the issue is with:

        pA.grad = avg
        pB.grad = avg

Here you set the same Tensor to be the gradient for both parameters. The backward pass accumulates into an existing .grad in place, so from then on both backward passes will accumulate in that same Tensor.
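To make the aliasing concrete, here is a minimal sketch with two stand-alone parameters instead of full models. The names w_a, w_b and shared are made up for illustration and are not from the original code.

```python
import torch

# Two independent parameters standing in for one parameter of each model
# (hypothetical names, for illustration only).
w_a = torch.ones(2, requires_grad=True)
w_b = torch.ones(2, requires_grad=True)

# Assign the SAME tensor object as .grad of both, as mergeGrad does with avg.
shared = torch.zeros(2)
w_a.grad = shared
w_b.grad = shared

# Run a backward pass that involves only w_a.
(w_a * 3).sum().backward()

# The gradient was accumulated in place into the shared tensor,
# so w_b.grad changed too, although w_b was never used.
print(w_a.grad, w_b.grad)
```

Both prints show the same values because there is only one underlying gradient tensor.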
If you don’t want that, you need to add a .clone() for at least one of them.
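A possible version of the merge function with the .clone() added could look like the sketch below. It assumes both models have the same architecture so that the parameters line up in zip; the name merge_grad, the None check, and the small nn.Linear models are illustrative additions, not part of the original code.

```python
import torch
import torch.nn as nn

def merge_grad(model_a, model_b):
    """Replace the gradients of both models with their average."""
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        if p_a.grad is None or p_b.grad is None:
            continue  # this parameter received no gradient in this step
        avg = (p_a.grad + p_b.grad) / 2
        p_a.grad = avg
        # clone() gives model B its own tensor, so later backward passes
        # of the two models no longer accumulate into shared memory.
        p_b.grad = avg.clone()

# Small stand-in models seeing different random batches (assumption).
torch.manual_seed(0)
model_a, model_b = nn.Linear(3, 1), nn.Linear(3, 1)
model_a(torch.randn(4, 3)).sum().backward()
model_b(torch.randn(4, 3)).sum().backward()

merge_grad(model_a, model_b)
```

After the call, both models hold equal gradient values, but in separate tensors.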

Note that this would also be a neat way to implement what you want: make the .grad fields the same Tensor, so the two backward passes actually accumulate in it and you get the sum of the two gradients. Then divide your learning rate by 2 so that you still do the same update as with the average (assuming simple SGD).
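A rough sketch of that shared-gradient trick, under a few assumptions that are not in the original post: both models are small nn.Linear layers that start from identical weights, a single plain SGD optimizer holds the parameters of both, and base_lr is a made-up value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model_a = nn.Linear(3, 1)
model_b = nn.Linear(3, 1)
model_b.load_state_dict(model_a.state_dict())  # assumption: identical start

# Give each pair of matching parameters one shared gradient tensor.
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    p_a.grad = torch.zeros_like(p_a)
    p_b.grad = p_a.grad

base_lr = 0.1  # hypothetical learning rate
params = list(model_a.parameters()) + list(model_b.parameters())
# Halve the learning rate: the shared tensor holds grad_A + grad_B,
# so (lr / 2) * sum equals lr * average.
opt = torch.optim.SGD(params, lr=base_lr / 2)

for _ in range(3):
    # Zero the shared tensors in place; setting .grad to None instead
    # would drop the sharing.
    opt.zero_grad(set_to_none=False)
    model_a(torch.randn(4, 3)).pow(2).mean().backward()  # batch for A
    model_b(torch.randn(4, 3)).pow(2).mean().backward()  # batch for B
    opt.step()
```

With momentum, weight decay, or an adaptive optimizer, halving the learning rate is no longer guaranteed to give the same update as averaging, which is why this is limited to simple SGD.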


Thanks a lot. grad.clone() solved the issue. Also appreciate the neat trick you mentioned!