Hi, I have a question about how to tie weights between different neurons. Suppose we have a 10 x 10 weight matrix `w`, and the rows `w[1,:]` and `w[2,:]` are the same, both equal to `w_0`. When training this tiny model, I want to update `w_0` instead of updating `w[1,:]` and `w[2,:]` separately, with the gradient given by `g(w_0) = (g(w[1,:]) + g(w[2,:])) / 2`.

I am not sure how to perform these operations properly. Thanks in advance for any help.

```python
import torch

def weight_sharing(weights, weights_grad, group_info):
    # group_info: 1-D LongTensor holding the indices of the rows that share weights
    # share weights: set every row in the group to the group's mean row
    avg_weight = weights[group_info].mean(dim=0)
    weights[group_info] = avg_weight
    # share gradient: set every row's gradient to the group's mean gradient
    avg_grad = weights_grad[group_info].mean(dim=0)
    weights_grad[group_info] = avg_grad
```

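Sample usage:

```python
import torch
import torch.nn as nn

linear = nn.Linear(10, 5)
group_info = torch.tensor([1, 2])  # rows 1 and 2 of linear.weight share weights
x = torch.randn(5, 10)
y = torch.randn(5, 5)
y_c = linear(x)
loss = (y_c - y).pow(2).mean()
loss.backward()
weight_sharing(linear.weight.data, linear.weight.grad.data, group_info)
# then update the parameters as usual
```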
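A slightly longer sketch of the same idea inside a training loop, assuming plain SGD (`torch.optim.SGD`). `share_rows` is a hypothetical helper doing what the function above does for one tensor. The tied rows are made equal once before training; after that, only the gradients need averaging, because identical rows that receive identical gradients stay identical under an elementwise update such as SGD.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(10, 5)
group_info = torch.tensor([1, 2])  # rows that share weights (example choice)
opt = torch.optim.SGD(linear.parameters(), lr=0.1)

def share_rows(t, idx):
    # hypothetical helper: set every row listed in idx to the mean of those rows (in place)
    t[idx] = t[idx].mean(dim=0)

with torch.no_grad():
    share_rows(linear.weight, group_info)  # make the tied rows equal before training

for step in range(5):
    x = torch.randn(8, 10)
    y = torch.randn(8, 5)
    opt.zero_grad()
    loss = (linear(x) - y).pow(2).mean()
    loss.backward()
    with torch.no_grad():
        share_rows(linear.weight.grad, group_info)  # average the tied rows' gradients
    opt.step()
```

After the loop, `linear.weight[1]` and `linear.weight[2]` are still exactly equal.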

I still have one question: since the weight sharing happens after the backward pass, I think it won't affect the computation graph. I believe this is true, but I'm not sure.

Ah, sorry, I misunderstood.
I'd probably do something like

```python
self.w_raw = nn.Parameter(torch.randn(9, 10))
```

in `__init__` and then in `forward`:

```python
w = torch.cat([self.w_raw[:1], self.w_raw], 0)
```
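Putting the two pieces together, a minimal sketch of such a module might look like this. `TiedRowLinear` is a hypothetical name, the bias is omitted for brevity, and `F.linear` is used to apply the assembled 10 x 10 weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedRowLinear(nn.Module):
    """Hypothetical 10 -> 10 linear layer (no bias) whose output rows 0 and 1 share one weight row."""

    def __init__(self):
        super().__init__()
        # 9 free rows; row 0 is reused for both row 0 and row 1 of the full weight
        self.w_raw = nn.Parameter(torch.randn(9, 10))

    def forward(self, x):
        # build the full 10 x 10 weight by duplicating the first raw row
        w = torch.cat([self.w_raw[:1], self.w_raw], 0)
        return F.linear(x, w)

layer = TiedRowLinear()
x = torch.randn(4, 10)
out = layer(x)
out.sum().backward()
```

Since only `w_raw` is a parameter, any optimizer updates the shared row once, with the gradient accumulated from both of its uses.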

Here is a little demo of what this does:

```python
import torch

w_raw = torch.randn(9, 10, requires_grad=True)
w = torch.cat([w_raw[:1], w_raw], 0)
w.retain_grad()  # keep the gradient of the non-leaf tensor w for inspection
og = torch.randn(10, 10)
w.backward(og)
print(w.grad, w_raw.grad)
```

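You can see that `w_raw.grad[0]` is the sum of the first two rows of `w.grad`: row 0 of `w_raw` is used twice in `w` (as rows 0 and 1), so autograd accumulates the gradients from both uses.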
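The same result follows from the chain rule. If the loss $L$ depends on $w_0$ only through rows 0 and 1 of $w$, both of which equal $w_0$, then

$$
\frac{\partial L}{\partial w_0}
  = \frac{\partial L}{\partial w[0,:]} + \frac{\partial L}{\partial w[1,:]},
\qquad
g_{\text{avg}}(w_0) = \frac{1}{2}\,\frac{\partial L}{\partial w_0},
$$

so the average asked for in the question is the true gradient of the tied weight scaled by 1/2, which with plain SGD amounts to halving the learning rate for that row.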

If you really need the average rather than the sum, you could multiply that row of the gradient by 0.5 before processing it further (or use a backward hook), but I'd expect the sum to work for you in most use cases.
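A minimal sketch of the backward-hook variant, assuming the same 9 x 10 `w_raw` as in the demo. `Tensor.register_hook` is called with the gradient for `w_raw` (already summed over both uses inside `torch.cat`) and may return a replacement; here it halves row 0 so that `w_raw.grad[0]` becomes the average. The `scale` tensor is just one illustrative way to do this.

```python
import torch

w_raw = torch.randn(9, 10, requires_grad=True)

# per-row scale: 0.5 for the shared row 0, 1.0 for all other rows
scale = torch.ones(9, 1)
scale[0] = 0.5
# the hook receives the gradient for w_raw and returns a rescaled copy (no in-place edit)
handle = w_raw.register_hook(lambda grad: grad * scale)

w = torch.cat([w_raw[:1], w_raw], 0)  # rows 0 and 1 of w are both w_raw[0]
og = torch.randn(10, 10)
w.backward(og)
```

The hook stays registered and applies on every backward pass; call `handle.remove()` if you want to go back to the plain sum.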