Tied weights within one linear layer

Hi,

I'm relatively new to PyTorch and am trying to play with a toy demo to understand the autograd functionality better. I'm trying to create a custom nn.Module which is similar to the Linear layer, except the outgoing weights are tied (in a per-input way). E.g. suppose I have input x = [x1 x2 x3 x4] and 4 hidden units. I want the weight matrix to look like:

W = [ w1 w1 w1 w1
      w2 w2 w2 w2
      w3 w3 w3 w3
      w4 w4 w4 w4 ]

That is, I effectively only want to store 4 weights and 4 biases in this layer: [w1 w2 w3 w4] and [b1 b2 b3 b4]. But during the computation of forward() and backward(), I'm hoping the stored weights can be expanded into the full matrix W and the hidden units calculated as:

h = x * W

In particular, I'm unsure how to implement the "storing [w1 w2 w3 w4] and expanding into the full W matrix" part. I'm not sure whether autograd will send the gradient information back through all 4x4 = 16 "effective weights" at computation time. I understand this model isn't very useful, but I'm doing it to learn how to manipulate nn.Modules at a finer level so I can do more useful things in the future.
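To spell out what I'm hoping for (just a sketch of the math, using the W above where W[i][j] = wi for every j), I think the chain rule should give something like

dL/dw1 = dL/dW[1][1] + dL/dW[1][2] + dL/dW[1][3] + dL/dW[1][4]

i.e. the gradient of each stored weight should accumulate contributions from every position it occupies in the expanded matrix.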

Here is an attempt at a solution:

class tiedLinear(nn.Module):
    def __init__(self, in_features, bias=True):
        super(tiedLinear, self).__init__()
        self.in_features = in_features
        self.weight = Parameter(torch.Tensor(in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # build the full weight matrix row by row; row i is the single
        # stored weight w_i repeated in_features times
        rows = []
        for i in range(0, self.in_features):
            row = self.repeat(self.weight[i], self.in_features)
            rows.append(row.unsqueeze(0))
        repeated_weight = torch.cat(rows, 0)

        return F.linear(input, repeated_weight, self.bias)

    def repeat(self, w, total):
        return torch.mul(torch.ones(total), w)

A nicer way to build repeated_weight is repeated_weight = self.weight.repeat(self.in_features, 1).
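To illustrate what that call produces (a quick standalone sketch, separate from the layer):

import torch

w = torch.Tensor([1, 2, 3, 4])
# .repeat(4, 1) stacks 4 copies of the whole vector as the rows of a 4x4 matrix
print(w.repeat(4, 1))
# rows: [1 2 3 4], [1 2 3 4], [1 2 3 4], [1 2 3 4]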

Autograd will compute gradients with respect to the in_features stored weights. For example, you can do the following (I took your code and modified it a little):

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable
import torch.nn.functional as F
import math

class tiedLinear(nn.Module):
    def __init__(self, in_features, bias=True):
        super(tiedLinear, self).__init__()
        self.in_features = in_features
        self.weight = Parameter(torch.Tensor(in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        repeated_weight = self.weight.repeat(self.in_features, 1)
        return F.linear(input, repeated_weight, self.bias)

layer = tiedLinear(4, False)
x = Variable(torch.randn(4))
out = sum(layer(x))
out.backward()
params = next(layer.parameters())
print(x)
print(params.grad)  # will be 4 times x because of the repeated weights. Autograd has taken the repetition into account
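
# Optional sanity check (just a sketch; the names W_full/out_full are mine):
# expand the matrix by hand, backprop through it, and sum the 16 per-copy
# gradients for each shared weight.
W_full = Variable(layer.weight.data.repeat(4, 1), requires_grad=True)
out_full = sum(F.linear(x, W_full))
out_full.backward()
print(W_full.grad.sum(0))  # matches params.grad above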

In general, autograd keeps track of the operations performed on Variables (and Parameters, which are Variables) and uses that to compute gradients. You can learn more about it here: http://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html
Does that answer your question?

I ran the code and everything looks good. Let me see if I understand you correctly:

self.weight.repeat(self.in_features, 1) repeats each weight, say w1, 4 times. Autograd is smart enough to calculate dw1 by considering all activations affected by the copies of w1.

If this is correct, then you've answered the question; amazing, thanks! It seems to me this is non-trivial behaviour (I don't think you could guess that .repeat() behaves this way just from the tutorial, which only uses simple operations like *).
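To convince myself, I tried a minimal check on a bare Variable (my own sketch):

import torch
from torch.autograd import Variable

w = Variable(torch.ones(4), requires_grad=True)
rep = w.repeat(3, 1)      # 3 copies of w stacked as rows
rep.sum().backward()
print(w.grad)             # every entry is 3: one contribution per copy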

p.s. After thinking about it a bit, I think it makes sense. repeated_weight either stores the Python id of the self.weight object 4 times, or maybe each element of repeated_weight has its own Python id. Does autograd figure out the contribution of the weights to the loss (out) through their Python id?

Yes, that's what I meant. Your explanation of how autograd works isn't quite right, though.

Here's some more detail on how autograd does that. The repeat function is implemented as an autograd function with its own forward() and backward(). Autograd will call the backward function when computing gradients: repeat's backward takes into account how the forward pass repeated the data.
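Conceptually (this is just a paraphrase, not the actual source), for a 1-D weight repeated along a new leading dimension, the backward simply sums the incoming gradient over the copies:

def repeat_rows_backward(grad_output):
    # the forward turned a vector of size n into an (m, n) matrix of m copies,
    # so the gradient w.r.t. the original vector is the sum over the m rows
    return grad_output.sum(0)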

Ah, that makes more sense. People have already implemented repeat as a function with its own forward() and backward(). So does this mean that if I tried to use my original code, it wouldn't work because it doesn't know how to backward()?

Autograd will also work with your original code. The cat and repeat functions both have a backward() implemented, and autograd will call those when computing gradients. Most functions that you can apply to a Variable have a backward() defined somewhere.
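For example (a quick sketch), gradients flow through torch.cat as well:

import torch
from torch.autograd import Variable

a = Variable(torch.ones(2), requires_grad=True)
b = torch.cat([a * 2, a * 3], 0)  # cat's backward routes gradients back to each piece
b.sum().backward()
print(a.grad)                     # [5, 5]: 2 from the first piece + 3 from the second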


Adding on to what Richard said, I think the implementation would be more efficient if you used .expand, as that doesn't copy the weights 4 times.

Do you have a link to the documentation? Are you saying .expand "makes a promise" to do the calculation 4 times with the one vector (instead of expanding into a matrix containing 4 copies of the same vector)?

http://pytorch.org/docs/0.2.0/tensors.html#torch.Tensor.expand It essentially just changes one stride to 0, so it is effectively a matrix with the tensor repeated four times, but the underlying storage is the same vector. And of course autograd handles it as well. See https://github.com/pytorch/pytorch/blob/master/torch/autograd/_functions/tensor.py#L141.
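A small sketch of the difference (example names are my own); in the layer you could presumably write something like self.weight.expand(self.in_features, self.in_features) in forward():

import torch

w = torch.randn(4)
r = w.repeat(4, 1)                 # 4x4 copy: 16 elements of storage
e = w.unsqueeze(0).expand(4, 4)    # 4x4 view: stride 0 along dim 0, no copy
print(r.stride(), e.stride())      # (4, 1) vs (0, 1)
print(len(r.storage()), len(e.storage()))  # 16 vs 4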
