Parameter / Weight sharing

I’m a little lost on how to perform weight sharing in PyTorch. In my case I would like to do the following:

Essentially I would like to reuse the weights of conv2d layer q and a second weight tensor w2, concatenated together, inside a loop. The weights of q are used both by layer q itself and in the loop, while w2 would ONLY be used by the convolution in the loop (however, since that convolution is repeated n times, technically w2 would be shared as well).
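To make the shapes concrete, here is a small sketch (the tensor sizes are arbitrary assumptions; the kernel shapes come from the Theano snippet further down). Concatenating the two (10, 1, 3, 3) kernels along the input-channel axis gives one (10, 2, 3, 3) kernel, and convolving the concatenated input with it is the same as convolving r with the q weights and v with w2 and summing. That is why the concatenation effectively shares q's weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Kernel shapes taken from the Theano snippet: both are (10, 1, 3, 3).
w1 = torch.randn(10, 1, 3, 3)  # weights of conv layer q
w2 = torch.randn(10, 1, 3, 3)  # extra weights used only inside the loop

# Arbitrary example inputs: batch of 4 single-channel 8x8 maps.
r = torch.randn(4, 1, 8, 8)
v = torch.randn(4, 1, 8, 8)

# Concatenate along the input-channel dimension -> (10, 2, 3, 3) kernel.
w_cat = torch.cat([w1, w2], dim=1)
out_cat = F.conv2d(torch.cat([r, v], dim=1), w_cat, padding=1)

# Equivalent: convolve each input with its own slice of the kernel and sum.
out_sum = F.conv2d(r, w1, padding=1) + F.conv2d(v, w2, padding=1)

print(w_cat.shape)                                # torch.Size([10, 2, 3, 3])
print(torch.allclose(out_cat, out_sum, atol=1e-5))  # True
```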

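import torch
import torch.nn as nn

class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config
        # nn.Parameter (rather than Variable) so w2 is registered with the module
        self.w2 = nn.Parameter(torch.randn(10, 1, 3, 3))
        # out_channels follow the Theano weight shapes
        self.h = nn.Conv2d(in_channels=2, out_channels=150,
                           kernel_size=(3, 3),
                           stride=1, padding=1)
        self.r = nn.Conv2d(in_channels=150, out_channels=1,
                           kernel_size=(1, 1),
                           stride=1, padding=0)  # a 1x1 kernel needs no padding to keep the shape
        self.q = nn.Conv2d(in_channels=1, out_channels=10,
                           kernel_size=(3, 3),
                           stride=1, padding=1)

    def forward(self, x):
        h = self.h(x)
        r = self.r(h)
        q = self.q(r)
        v = torch.max(q, dim=1, keepdim=True)[0]  # keepdim so v can be concatenated with r
        for i in range(n):  # n: number of loop iterations (not defined in this snippet)
            rv = torch.cat([r, v], 1)
            Conv2d(rv, weights=[self.q.weight, self.w2])  # This is obviously not valid torch code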

In Theano it could be accomplished as follows:

# Helper function
def conv2D_keep_shape(x, w, subsample=(1, 1)):
    # crop output to same size as input
    fs = T.shape(w)[2] - 1  # this is the filter size minus 1
    ims = T.shape(x)[2]     # this is the image size
    return theano.sandbox.cuda.dnn.dnn_conv(img=x,
                                            kerns=w,
                                            border_mode='full',
                                            subsample=subsample,
                                            )[:, :, fs//2:ims+fs//2, fs//2:ims+fs//2]

# Weights
self.w = theano.shared((np.random.randn(1, 150, 1, 1)).astype(theano.config.floatX))
self.w1 = theano.shared((np.random.randn(10, 1, 3, 3)).astype(theano.config.floatX))
self.w2 = theano.shared((np.random.randn(10, 1, 3, 3)).astype(theano.config.floatX))

# Model
self.r = conv2D_keep_shape(input, self.w)

self.q = conv2D_keep_shape(self.r, self.w1)

self.v = T.max(self.q, axis=1, keepdims=True)

for i in range(n):
    self.q = conv2D_keep_shape(T.concatenate([self.r, self.v], axis=1), T.concatenate([self.w1, self.w2], axis=1))
    self.v = T.max(self.q, axis=1, keepdims=True)
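For anyone translating the helper, its "full convolution, then crop back to the input size" step corresponds to ordinary same-size padding in PyTorch. The sketch below is a hypothetical PyTorch re-implementation of the crop (the function name is made up for illustration), checked against `padding=(k - 1) // 2` for an odd kernel size. It only checks the shape and cropping logic; it assumes square images, as the Theano code does, and does not try to reproduce any kernel-flipping differences between the two libraries.

```python
import torch
import torch.nn.functional as F

def conv2d_full_then_crop(x, w):
    """Hypothetical PyTorch analogue of conv2D_keep_shape: 'full' conv, then center crop."""
    fs = w.shape[2] - 1  # filter size minus 1
    ims = x.shape[2]     # image size (square images assumed)
    full = F.conv2d(x, w, padding=fs)  # 'full' mode: output is ims + fs wide
    return full[:, :, fs // 2:ims + fs // 2, fs // 2:ims + fs // 2]

torch.manual_seed(0)
x = torch.randn(2, 1, 8, 8)
w = torch.randn(10, 1, 3, 3)

cropped = conv2d_full_then_crop(x, w)
same = F.conv2d(x, w, padding=1)  # padding = (3 - 1) // 2 keeps the 8x8 size

print(cropped.shape)                             # torch.Size([2, 10, 8, 8])
print(torch.allclose(cropped, same, atol=1e-5))  # True
```

So in PyTorch the helper is not needed: setting `padding` on `nn.Conv2d` or `F.conv2d` gives the same output shape directly.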

Is this possible to achieve in PyTorch? Please let me know if you need more explanation in case the above is not clear.


Just thought of a possible workaround, although it’s fairly hacky and I’m hoping there is a cleaner way to do it.

It would (at least in theory) be possible to modify the Conv2d class in /torch/nn/modules/ so that its forward function accepts a “weight” argument if one is supplied, and otherwise simply uses self.weight.

I’m not super jazzed about modifying the PyTorch source unless it’s really necessary, so hopefully someone with more than my admittedly short (one day) experience with PyTorch will have a better solution.
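The same idea can be sketched without editing PyTorch's source by subclassing `nn.Conv2d` and overriding `forward`. `FlexibleConv2d` is a hypothetical name, and dropping the bias whenever an external weight is passed is a design choice of this sketch (the Theano model has no biases); it also assumes the default `padding_mode='zeros'`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FlexibleConv2d(nn.Conv2d):
    """Hypothetical Conv2d subclass whose forward optionally takes an external weight."""

    def forward(self, x, weight=None):
        # Fall back to the layer's own parameter when no weight is supplied.
        w = self.weight if weight is None else weight
        # Sketch assumption: use the bias only together with the layer's own weight.
        b = self.bias if weight is None else None
        # Assumes padding_mode='zeros' (the default), so F.conv2d can pad directly.
        return F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)

conv = FlexibleConv2d(1, 10, kernel_size=3, padding=1)
w2 = nn.Parameter(torch.randn(10, 1, 3, 3))
x1 = torch.randn(4, 1, 8, 8)
x2 = torch.randn(4, 2, 8, 8)

y1 = conv(x1)                                              # uses conv.weight and conv.bias
y2 = conv(x2, weight=torch.cat([conv.weight, w2], dim=1))  # reuses conv.weight
print(y1.shape, y2.shape)  # torch.Size([4, 10, 8, 8]) torch.Size([4, 10, 8, 8])
```

Because `torch.cat` is differentiable, gradients from `y2` flow back into both `conv.weight` and `w2`.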

You can use the functional form of Conv2d, torch.nn.functional.conv2d, which takes the weight tensor as an argument.
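A minimal sketch of what that means (tensor sizes are arbitrary): calling the module and calling `F.conv2d` with the module's own `weight` and `bias` compute the same result, and because the functional form accepts any weight tensor, the same `Parameter` can be applied to several inputs, with gradients from every use accumulating in its single `.grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1)
x = torch.randn(4, 1, 8, 8)

# The module call and the functional call compute the same thing.
y_module = conv(x)
y_functional = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)
print(torch.allclose(y_module, y_functional))  # True

# Reusing the same Parameter on two inputs: gradients from both uses accumulate.
a = torch.randn(4, 1, 8, 8)
b = torch.randn(4, 1, 8, 8)
out = F.conv2d(a, conv.weight, padding=1) + F.conv2d(b, conv.weight, padding=1)
out.sum().backward()
print(conv.weight.grad.shape)  # torch.Size([10, 1, 3, 3])
```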

So would this be the correct way to do it? (Is calling a functional component inside the forward() function valid?)

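    def forward(self, x):
        h = self.h(x)
        r = self.r(h)
        q = self.q(r)
        v = torch.max(q, dim=1, keepdim=True)[0]  # keepdim so v can be concatenated with r
        for i in range(n):
            # reuse self.q.weight, concatenated with the extra weights self.w_fb
            q = F.conv2d(torch.cat([r, v], 1),
                         torch.cat([self.q.weight, self.w_fb], 1),
                         stride=1, padding=1)
            v = torch.max(q, dim=1, keepdim=True)[0]
        return q, v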

Yes, that’s one way to do it.
And yes, calling a functional component inside forward() is valid.
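Putting the pieces together, here is a minimal runnable sketch of the model from the question. The channel counts follow the Theano weight shapes; `n_iter` is a hypothetical name standing in for `n`, and the biases on `r` and `q` are turned off only to mirror the bias-free Theano code (an assumption, not a requirement).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Sketch: q's weights are reused inside the loop together with w2."""

    def __init__(self, n_iter=10):
        super().__init__()
        self.n_iter = n_iter  # hypothetical name for n in the question
        self.h = nn.Conv2d(2, 150, kernel_size=3, stride=1, padding=1)
        self.r = nn.Conv2d(150, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.q = nn.Conv2d(1, 10, kernel_size=3, stride=1, padding=1, bias=False)
        # nn.Parameter (not a plain tensor) so w2 is registered and seen by the optimizer.
        self.w2 = nn.Parameter(torch.randn(10, 1, 3, 3))

    def forward(self, x):
        h = self.h(x)
        r = self.r(h)
        q = self.q(r)
        v = torch.max(q, dim=1, keepdim=True)[0]
        for _ in range(self.n_iter):
            # Shared kernel: (10, 1, 3, 3) + (10, 1, 3, 3) -> (10, 2, 3, 3).
            w = torch.cat([self.q.weight, self.w2], dim=1)
            q = F.conv2d(torch.cat([r, v], dim=1), w, stride=1, padding=1)
            v = torch.max(q, dim=1, keepdim=True)[0]
        return q, v

model = Model(n_iter=5)
q, v = model(torch.randn(4, 2, 8, 8))
print(q.shape, v.shape)  # torch.Size([4, 10, 8, 8]) torch.Size([4, 1, 8, 8])
(q.sum() + v.sum()).backward()
print(model.q.weight.grad is not None, model.w2.grad is not None)  # True True
```

The concatenated kernel is rebuilt on every iteration, which is cheap and keeps `self.q.weight` as the single source of truth, so `self.q` and the loop always stay in sync during training.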


I’d like to know the answer from a (Lua) Torch perspective. Suppose my network looks like this:
a = nn.Sequential()

In Torch I would have done:
a:get(1).weight:set(a:get(2).weight)
and similarly for bias, gradWeight, and gradBias.

In PyTorch:
self.fc1 = nn.Linear(10,20)
self.fc7 = nn.Linear(10,20)

What should I do? Is
self.fc7.weight = self.fc1.weight
(and similarly for bias) the right approach? As far as I can see there is no set command, and the Linear object has no gradWeight attribute either!

Would that be enough to share gradWeight and gradBias as well?



Do you mean something like How to create model with sharing weight??
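For the fc1/fc7 case specifically, a minimal sketch (the class name `Net` and the forward pass are illustrative only): PyTorch has no separate gradWeight/gradBias; the gradient lives in the `.grad` attribute of the `Parameter` itself. So assigning the same `Parameter` objects to both layers shares the weights and the gradients, and `parameters()` lists each shared tensor only once.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc7 = nn.Linear(10, 20)
        # Point fc7 at the very same Parameter objects as fc1.
        self.fc7.weight = self.fc1.weight
        self.fc7.bias = self.fc1.bias

    def forward(self, x):
        return self.fc1(x) + self.fc7(x)

net = Net()
print(net.fc7.weight is net.fc1.weight)  # True
print(len(list(net.parameters())))      # 2 (shared parameters are listed once)

net(torch.randn(3, 10)).sum().backward()
# There is only one .grad; it accumulates contributions from both layers.
print(net.fc1.weight.grad is net.fc7.weight.grad)  # True
```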

for i in range(n):