Reusing altered parameters in more than one layer

Hello everyone,
I’m trying to build a model in which one parameter is used in two layers, but in a slightly altered form in one of them. It looks something like this:

import torch
import torch.nn as nn

class CrNN(nn.Module):
    def __init__(self, n, m):
        super().__init__()
        self.P1 = nn.Parameter(torch.rand((m, n)))  # weights
        self.P2 = nn.Parameter(torch.rand(m))       # bias

        self.Layer1 = nn.Linear(n, m)
        self.Layer1.bias = self.P2    # This is the correct parameter for the bias

        self.Layer2 = nn.Linear(m, n, bias=False)
        self.Layer2.weight = self.P1  # This is the correct parameter for this layer

Now I would like my net to be described by only these two parameters, with the weights of the first layer expressed in terms of P1 as follows:

torch.where(self.P1<0,-self.P1,0)

so for the first layer’s weights I am only interested in the negative entries of P1, negated (all other entries should be zero).

How can this be done? If I simply perform this computation in the initializer, I get a new parameter that is independent of P1.
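Roughly, what I tried looks like this (just a sketch, not my exact code) – added at the end of __init__ above:

# creates a brand-new Parameter: it is initialized from P1, but afterwards
# it is trained independently, so Layer1's weights no longer follow P1
altered = torch.where(self.P1 < 0, -self.P1, torch.zeros_like(self.P1))
self.Layer1.weight = nn.Parameter(altered.detach())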

Thanks for your help!

Hi Hannes!

I would suggest keeping your Parameters, P1 and P2, as is, but using
the functional form of linear(), rather than the class Linear.

Because torch.where() doesn’t backpropagate cleanly (assuming that
you do want to backpropagate through the altered layer1 weights), you
should use something like minimum() in place of where(). (The example
I give below backpropagates contributions to the gradient of P1 from its
use both in layer1 (in altered form) and in layer2.)
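For reference, here is a minimal sketch (separate from your model) showing that the minimum()-based expression produces the same altered weights as your where() expression:

import torch

P1 = torch.rand (4, 3) - 0.5                                  # mix of positive and negative entries
w_where = torch.where (P1 < 0, -P1, torch.zeros_like (P1))    # your original expression
w_min = -torch.minimum (P1, torch.zeros (1))                  # the version used below
print (torch.allclose (w_where, w_min))                       # True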

Note that there is no need to wrap P1 and P2 in Linears – as attributes of
your CrNN Module, they are first-class Parameters that can be passed
to optimizers, etc.
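For example, here is a minimal sketch of one training step (assuming the CrNN class from the script below has been defined; the hyperparameters are just placeholders):

mod = CrNN (n = 5, m = 10)
opt = torch.optim.SGD (mod.parameters(), lr = 0.1)   # finds P1 and P2 automatically

loss = (mod (torch.randn (3, 5))**2).sum()
opt.zero_grad()
loss.backward()
opt.step()                                           # updates P1 and P2 in place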

These points are illustrated in the following script:

import torch
print (torch.__version__)

torch.manual_seed (2022)

class CrNN (torch.nn.Module):
    def __init__ (self, n, m):
        super().__init__()
        self.P1 = torch.nn.Parameter (torch.rand (m, n))  # weights
        self.P2 = torch.nn.Parameter (torch.rand (m))     # bias
    
    def forward (self, input):   # assumes input has shape [*, n], e.g., [nBatch, n]
        layer1_weight = -torch.minimum (self.P1, -torch.zeros (1))
        x = torch.nn.functional.linear (input, layer1_weight, self.P2)   # apply layer1, with bias
        x = torch.nn.functional.linear (x, self.P1.T)                    # apply layer2, no bias
        return x

nBatch = 3
n = 5
m = 10

mod = CrNN (n, m)
print (list (mod.parameters()))

input = torch.randn (nBatch, n)
output = mod (input)
print (input)
print (output)
loss = (output**2).sum()
loss.backward()
print (mod.P1.grad)

Here is its output:

1.12.0
[Parameter containing:
tensor([[0.3958, 0.9219, 0.7588, 0.3811, 0.0262],
        [0.3594, 0.7933, 0.7811, 0.4643, 0.6329],
        [0.6689, 0.2302, 0.8003, 0.7353, 0.7477],
        [0.5585, 0.6226, 0.8429, 0.6105, 0.1248],
        [0.8265, 0.2117, 0.8574, 0.4282, 0.3964],
        [0.1440, 0.0034, 0.9504, 0.2194, 0.2893],
        [0.6784, 0.4997, 0.9144, 0.2833, 0.5739],
        [0.2444, 0.2476, 0.1210, 0.6869, 0.6617],
        [0.5168, 0.9089, 0.8799, 0.6949, 0.4609],
        [0.1263, 0.6332, 0.4839, 0.7779, 0.9180]], requires_grad=True), Parameter containing:
tensor([0.0768, 0.9693, 0.2956, 0.7251, 0.5438, 0.7403, 0.3211, 0.5044, 0.6463,
        0.9245], requires_grad=True)]
tensor([[-1.3409,  0.4439, -0.5857, -0.9341, -1.2294],
        [-2.0586,  1.0720, -1.9449,  0.0958,  0.4007],
        [-1.0349, -2.1315,  1.0541, -1.7012,  0.5013]])
tensor([[2.3293, 2.9349, 4.2035, 3.1404, 3.0213],
        [2.3293, 2.9349, 4.2035, 3.1404, 3.0213],
        [2.3293, 2.9349, 4.2035, 3.1404, 3.0213]], grad_fn=<MmBackward0>)
tensor([[ 1.0731,  1.3521,  1.9365,  1.4468,  1.3919],
        [13.5460, 17.0682, 24.4456, 18.2628, 17.5704],
        [ 4.1309,  5.2050,  7.4547,  5.5693,  5.3581],
        [10.1338, 12.7687, 18.2878, 13.6625, 13.1445],
        [ 7.5995,  9.5755, 13.7143, 10.2457,  9.8573],
        [10.3460, 13.0362, 18.6708, 13.9486, 13.4198],
        [ 4.4877,  5.6546,  8.0987,  6.0504,  5.8210],
        [ 7.0495,  8.8824, 12.7217,  9.5042,  9.1438],
        [ 9.0319, 11.3804, 16.2993, 12.1769, 11.7152],
        [12.9205, 16.2801, 23.3168, 17.4196, 16.7591]])

Best.

K. Frank


Thank you, that’s working perfectly!