A layer whose weight is the sum of two other layers' weights, but its gradient is None

I want to create a linear layer whose weight is a weighted sum of the weights of two other linear layers. The weighting factor is denoted alpha, i.e.

outputs = [alpha*W1+(1-alpha)*W2] * inputs

I want the gradient of the loss w.r.t. W1, W2, and also alpha, but after loss.backward() all of these gradients are None. It seems that W1, W2, and alpha are not part of the computational graph. Is there a way to incorporate them into the graph?

import torch
import torch.nn as nn

fc1 = nn.Linear(4, 6, bias=False)
fc2 = nn.Linear(4, 6, bias=False)
fc_sum = nn.Linear(4, 6, bias=False)
alpha = torch.tensor(0.1, requires_grad=True)
fc_sum.weight = nn.Parameter(fc1.weight.data * alpha + fc2.weight.data * (1 - alpha))

inputs = torch.rand([1, 4])
outputs = fc_sum(inputs)
loss = torch.sum(outputs ** 2)
loss.backward()

print(alpha.grad)
print(fc1.weight.grad)

Hi Wang!

The simplest approach would be:

outputs = alpha * fc1(inputs) + (1 - alpha) * fc2(inputs)

Calling loss.backward() will properly populate grad for fc1.weight,
fc2.weight, and alpha.

There is no need to package the weighted combination of fc1 and fc2
as a separate Linear.
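
For reference, here is a minimal, runnable sketch of that idea, using the same
shapes as in your snippet; all three gradients come out populated:

import torch
import torch.nn as nn

fc1 = nn.Linear(4, 6, bias=False)
fc2 = nn.Linear(4, 6, bias=False)
alpha = torch.tensor(0.1, requires_grad=True)

inputs = torch.rand([1, 4])
# Combine the layer *outputs*, so fc1.weight, fc2.weight, and alpha
# all participate in the computational graph.
outputs = alpha * fc1(inputs) + (1 - alpha) * fc2(inputs)
loss = torch.sum(outputs ** 2)
loss.backward()

print(alpha.grad)       # a scalar tensor, no longer None
print(fc1.weight.grad)  # populated
print(fc2.weight.grad)  # populated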

(If the simple approach doesn’t work for you, could you explain your use
case in greater detail?)

Best.

K. Frank

Thanks Frank.

However, the example I provided is only the simplest form. What I really want to do is parametrize every weight in a deep neural network (conv weights, linear weights, bn weights, etc.) with the parameter alpha. Also, the number of Wi may be greater than 2, and the interpolation method could be much more complicated than a linear combination. For example, rather than

W=alpha*W1+(1-alpha)*W2

I would like

W=f1(alpha)*W1+f2(alpha)*W2+...+fk(alpha)*Wk

W can be a weight not only from linear layers but also from conv layers and so on. I wonder whether it is still possible to write out the formula explicitly in this situation, so I would like to know whether we can work with the layer weights directly.

I think I have come up with a solution, with the help of torch.nn.functional:

import torch.nn.functional as F

class LinearSum(nn.Module):
    def __init__(self):
        super(LinearSum, self).__init__()
        self.weight1 = nn.Parameter(torch.randn(6, 4))
        self.weight2 = nn.Parameter(torch.randn(6, 4))

    def forward(self, x, a):
        # Build the combined weight inside forward so that weight1, weight2,
        # and a all stay in the computational graph.
        weight = self.weight1 * a + self.weight2 * (1 - a)
        return F.linear(x, weight, bias=None)


fc_sum = LinearSum()
alpha = torch.tensor(0.1, requires_grad=True)
inputs = torch.rand([1, 4])
outputs = fc_sum(inputs, alpha)
loss = torch.sum(outputs ** 2)
loss.backward()

print(alpha.grad)
print(fc_sum.weight1.grad)
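
The same pattern seems to carry over to conv (and other) layers: build the combined kernel inside forward and pass it to the matching functional op, e.g. F.conv2d. A rough sketch with k weights (the coefficient functions f_i(alpha) = alpha**(i+1) below are just placeholders for whatever interpolation is actually needed):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvSum(nn.Module):
    def __init__(self, k=3, in_channels=3, out_channels=8, kernel_size=3):
        super().__init__()
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
             for _ in range(k)]
        )

    def forward(self, x, a):
        # Placeholder coefficient functions of alpha; any differentiable
        # functions f_1(alpha), ..., f_k(alpha) would work here.
        coeffs = [a ** (i + 1) for i in range(len(self.weights))]
        # Build the combined kernel inside forward so every W_i and alpha
        # stay in the autograd graph.
        weight = sum(c * w for c, w in zip(coeffs, self.weights))
        return F.conv2d(x, weight, bias=None, padding=1)


conv_sum = ConvSum()
alpha = torch.tensor(0.1, requires_grad=True)
x = torch.rand(1, 3, 16, 16)
loss = torch.sum(conv_sum(x, alpha) ** 2)
loss.backward()

print(alpha.grad)
print(conv_sum.weights[0].grad.shape)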

Any other implementation or suggestion is welcome!