Tying weights from different neurons and averaging the gradients

(Gaosh) #1

Hi, I have a question about how to tie weights between different neurons. Suppose we have a 10 x 10 weight matrix w in which the rows w[1,:] and w[2,:] are identical and equal to w_0. When training this tiny model, I want to update w_0 instead of updating w[1,:] and w[2,:] separately, using the gradient g(w_0) = (g(w[1,:]) + g(w[2,:]))/2.
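For reference, if the loss L depends on w_0 only through those two rows, the chain rule relates the gradient of the shared row to the two row gradients, and shows that the averaged gradient asked for above is half of it:

$$
\frac{\partial L}{\partial w_0} = g(w[1,:]) + g(w[2,:]), \qquad
\frac{g(w[1,:]) + g(w[2,:])}{2} = \frac{1}{2}\,\frac{\partial L}{\partial w_0}.
$$

So with plain SGD, stepping w_0 with the averaged gradient is the same as stepping with the true (summed) gradient at half the learning rate.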

I am not sure how to perform these operations properly. Thanks in advance for any help.

(Gaosh) #2

I came up with a simple solution:

import torch

def weights_sharing(weights, weights_grad, group_info):
    # group_info: 1-D LongTensor with the indices of the rows that share weights.
    # Pass .data tensors (or call under torch.no_grad()) so autograd does not track the edits.

    # share weights: every row in the group is replaced by the group's mean row
    average_weight = weights.index_select(0, group_info).mean(dim=0)
    weights[group_info] = average_weight

    # share gradient: every row's gradient is replaced by the group's mean gradient
    average_grad = weights_grad.index_select(0, group_info).mean(dim=0)
    weights_grad[group_info] = average_grad

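Sample usage:

import torch
import torch.nn as nn

linear = nn.Linear(10, 5)
x, y = torch.randn(4, 10), torch.randn(4, 5)   # dummy batch
group_info = torch.tensor([1, 2])              # rows of linear.weight that share weights
loss = (linear(x) - y).pow(2).mean()
loss.backward()                                # fills linear.weight.grad before sharing
weights_sharing(linear.weight.data, linear.weight.grad.data, group_info)
# then update the parameter as usual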
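A minimal end-to-end sketch of this approach, assuming plain SGD and a hypothetical group of rows 1 and 2 with dummy data. It re-declares `weights_sharing` (selecting the grouped rows with `index_select`) so it runs on its own, and checks that the tied rows are still identical after a few updates.

```python
import torch
import torch.nn as nn


def weights_sharing(weights, weights_grad, group_info):
    # Average the grouped rows, then average their gradients.
    weights[group_info] = weights.index_select(0, group_info).mean(dim=0)
    weights_grad[group_info] = weights_grad.index_select(0, group_info).mean(dim=0)


torch.manual_seed(0)
linear = nn.Linear(10, 5)
group_info = torch.tensor([1, 2])  # hypothetical choice: tie rows 1 and 2
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

for _ in range(3):
    x, y = torch.randn(8, 10), torch.randn(8, 5)  # dummy data
    optimizer.zero_grad()
    loss = (linear(x) - y).pow(2).mean()
    loss.backward()
    # Share after backward, before the optimizer step.
    weights_sharing(linear.weight.data, linear.weight.grad.data, group_info)
    optimizer.step()

print(torch.allclose(linear.weight[1], linear.weight[2]))  # True: rows stay tied
```

Because the rows start each step equal and receive equal gradients, the update keeps them equal; the sharing call only has to happen between `backward()` and `step()`.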

One remaining question: since the weight sharing happens after the backward pass, I believe it won't affect the computation graph, but I'm not sure.

Any discussion is welcome.

(Thomas V) #3

You can just assign the same Parameter to two different modules. You’ll get the sum of gradients, but that should be OK.
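A small sketch of this Parameter-sharing approach, assuming two bias-free `nn.Linear` layers of the same shape (sizes and inputs are arbitrary). It checks that the shared weight's gradient equals the sum of the gradients from each use.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
a = nn.Linear(4, 3, bias=False)
b = nn.Linear(4, 3, bias=False)
b.weight = a.weight  # both modules now hold the very same Parameter object

x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
loss = a(x1).sum() + b(x2).sum()
loss.backward()

# For comparison: the gradient of each use, computed separately on a detached copy.
w = a.weight.detach().clone().requires_grad_(True)
g1, = torch.autograd.grad((x1 @ w.t()).sum(), w)
g2, = torch.autograd.grad((x2 @ w.t()).sum(), w)

print(b.weight is a.weight)                    # True
print(torch.allclose(a.weight.grad, g1 + g2))  # True: gradients are summed
```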

Best regards


(Gaosh) #4

Hi, Thomas

Does this also apply to two different rows of the same weight matrix?

(Thomas V) #5

Ah, sorry, I misunderstood.
I’d probably do something like

  self.w_raw = nn.Parameter(torch.randn(9, 10))

in __init__ and then in forward:

  w = torch.cat([self.w_raw[:1], self.w_raw], 0)
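Putting those two lines into a module gives roughly the following sketch; the class name `TiedRowsLinear`, the 10 -> 10 size and the dummy training data are assumptions for illustration. Only 9 x 10 = 90 values are trainable, and rows 0 and 1 of the effective weight stay identical through training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedRowsLinear(nn.Module):
    """Hypothetical 10 -> 10 linear map whose first two weight rows are one shared row."""

    def __init__(self):
        super().__init__()
        # 9 free rows; row 0 is reused twice in forward.
        self.w_raw = nn.Parameter(torch.randn(9, 10))

    def forward(self, x):
        w = torch.cat([self.w_raw[:1], self.w_raw], 0)  # 10 x 10, rows 0 and 1 identical
        return F.linear(x, w)


torch.manual_seed(0)
model = TiedRowsLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 10), torch.randn(8, 10)  # dummy data

for _ in range(3):
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    loss.backward()
    optimizer.step()

out = model(x)
print(sum(p.numel() for p in model.parameters()))  # 90
print(torch.allclose(out[:, 0], out[:, 1]))        # True: outputs 0 and 1 always match
```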

Here is a little demo for what this does:

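import torch

w_raw = torch.randn(9, 10, requires_grad=True)
w = torch.cat([w_raw[:1], w_raw], 0)   # 10 x 10; rows 0 and 1 are both w_raw[0]
og = torch.randn(10, 10)               # stand-in for the gradient arriving at w
w.backward(og)
print(og, w_raw.grad)

You can see that w_raw.grad[0] is the sum of the first two rows of og (the gradient with respect to w), while w_raw.grad[1:] equals og[2:].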
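To tie arbitrary rows, such as w[1,:] and w[2,:] from the original question, one possible variant (swapping `torch.cat` for advanced indexing) is to build w from w_raw with an index tensor. The `index` mapping below is a hypothetical choice; indexing accumulates the gradients of repeated rows, so the shared row again receives the sum.

```python
import torch

torch.manual_seed(0)
# Hypothetical mapping: row i of the 10 x 10 weight is taken from row index[i] of w_raw.
# Rows 1 and 2 both read w_raw[1], matching w[1,:] == w[2,:] in the question.
index = torch.tensor([0, 1, 1, 2, 3, 4, 5, 6, 7, 8])
w_raw = torch.randn(9, 10, requires_grad=True)
w = w_raw[index]          # 10 x 10, built by advanced indexing
og = torch.randn(10, 10)  # stand-in for the gradient arriving at w
w.backward(og)

print(torch.allclose(w_raw.grad[1], og[1] + og[2]))  # True: shared row gets the sum
print(torch.allclose(w_raw.grad[0], og[0]))          # True: untied rows are unchanged
```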

If you really need the average rather than the sum, you could multiply the gradient by 0.5 before processing it further (or use a backward hook), but I'd expect the sum to work in most use cases…
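A minimal sketch of the backward-hook option, assuming the same `torch.cat` construction as above. The hook is registered on the intermediate tensor w and halves the gradient of its two tied rows before it flows back, so w_raw.grad[0] becomes the average of the two row gradients instead of their sum.

```python
import torch

torch.manual_seed(0)
w_raw = torch.randn(9, 10, requires_grad=True)
w = torch.cat([w_raw[:1], w_raw], 0)

# Scale the gradient of rows 0 and 1 of w by 0.5; leave the other rows untouched.
w.register_hook(lambda g: torch.cat([g[:2] * 0.5, g[2:]], 0))

og = torch.randn(10, 10)
(w * og).sum().backward()  # the gradient arriving at w is og

print(torch.allclose(w_raw.grad[0], (og[0] + og[1]) / 2))  # True: average, not sum
print(torch.allclose(w_raw.grad[1:], og[2:]))              # True
```

Inside a module, the hook would have to be registered in forward, because w is rebuilt on every call.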

Best regards