How to share a convolution kernel's transpose with another kernel?

Well, I am building a CNN with two conv layers. I want the two layers to share one kernel, with the second layer using that kernel’s transpose.
Here’s what I’ve written:

self.conv2.weight[0,0] = self.conv1.weight[0,0].t()

But I got the following error when executing it:
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

How should I solve it? Thanks in advance.

You could try to manipulate the data:

conv1.weight.data[0, 0] = conv2.weight.data[0, 0].t()

Thanks for your help, but I’ve tried this and it doesn’t work: it only initializes one kernel’s values from the other instead of keeping their weights tied throughout training. :pensive:

If you wanted to share entire weight tensors, you could do…

conv2.weight = conv1.weight.t()

and then train as usual. PyTorch should combine the updates for conv1.weight and conv2.weight automatically and keep both weight tensors in sync without any fuss.

But you want to share only part of the weight tensors, so I think you are going to have to monitor the weight updates and combine them manually.

# store copies of the shared slices before the update (.clone() so these are snapshots, not views)
w1 = conv1.weight[0,0].data.clone()
w2 = conv2.weight[0,0].data.clone()
# make update
optimizer.step()
# calculate updates
update1 = conv1.weight[0,0].data - w1
update2 = conv2.weight[0,0].data - w2
# combine updates: apply each slice's update (transposed) to the other slice
conv1.weight[0,0].data += update2.t()
conv2.weight[0,0].data += update1.t()
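
To make that concrete, here is a minimal runnable sketch around a single optimizer step; the toy two-layer model, the MSE loss, and the tensor sizes are placeholders I’ve assumed, not anything from the thread:

import torch
import torch.nn as nn
import torch.optim as optim

# toy model: two 1-channel conv layers whose [0,0] kernel slices we want tied by transpose
model = nn.Sequential(nn.Conv2d(1, 1, 3, padding=1), nn.Conv2d(1, 1, 3, padding=1))
conv1, conv2 = model[0], model[1]
conv2.weight.data[0, 0] = conv1.weight.data[0, 0].t()   # start with the slices tied
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs, targets = torch.randn(4, 1, 8, 8), torch.randn(4, 1, 8, 8)   # dummy batch

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()

# snapshot the tied slices, step, then cross-apply the transposed updates as above
w1 = conv1.weight[0, 0].data.clone()
w2 = conv2.weight[0, 0].data.clone()
optimizer.step()
update1 = conv1.weight[0, 0].data - w1
update2 = conv2.weight[0, 0].data - w2
conv1.weight[0, 0].data += update2.t()
conv2.weight[0, 0].data += update1.t()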

Thank you man. The reason I do

self.conv2.weight[0,0] = self.conv1.weight[0,0].t()

is that the convolution kernel is a 4d tensor, i.e.:

[torch.FloatTensor of size output_channel x input_channel x kernel_height x kernel_width]

So I cannot just use conv2.weight = conv1.weight.t() to transpose all of the kernels.
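
For reference, a quick shape check (the channel and kernel sizes here are arbitrary, just for illustration):

import torch.nn as nn

conv = nn.Conv2d(3, 6, kernel_size=5)
print(conv.weight.shape)   # torch.Size([6, 3, 5, 5]) -> (out_channels, in_channels, kH, kW)
# .t() only accepts tensors with at most 2 dimensions, so it cannot transpose this 4D weight directly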

But even if I replace the conv layers with linear layers, like this:

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(5, 10)
        self.linear2 = nn.Linear(10, 5)

        self.linear2.weight = self.linear1.weight.t()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)

        return x

net = Net()

Now I’m trying to share the whole weight tensor, but it still doesn’t work; it shows:

Traceback (most recent call last):
  File "/home/hdl2/Desktop/pytorchstudy/test.py", line 23, in <module>
    net = Net()
  File "/home/hdl2/Desktop/pytorchstudy/test.py", line 15, in __init__
    self.linear2.weight = self.linear1.weight.t()
  File "/home/hdl2/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 281, in __setattr__
    .format(torch.typename(value), name))
TypeError: cannot assign 'torch.autograd.variable.Variable' as parameter 'weight' (torch.nn.Parameter or None expected)

What should I do then?

Your code self.conv2.weight[0,0] = self.conv1.weight[0,0].t() only transposes the weight matrix for the part of the convolution that takes the first channel of the input to the first channel of the output; it doesn’t do anything for the other parts of the convolution.

If you want to transpose the convolutional matrices for all channels you can use .transpose(2,3) to swap the last two dimensions of a 4D tensor.
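
A minimal sketch of what that does to the shape (the sizes are arbitrary):

import torch

w = torch.randn(6, 3, 3, 5)        # (out_channels, in_channels, kH, kW)
print(w.transpose(2, 3).shape)     # torch.Size([6, 3, 5, 3]) -- kernel height and width swapped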

It looks like my idea for weight sharing doesn’t work at all, and a quick search of the forum shows that the only easyish way to do it is to use the functional form of conv2d.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # one shared kernel: used as-is by the first conv and transposed by the second
        self.conv_weight = nn.Parameter(torch.randn(1,1,5,5))

    def forward(self, x):
        x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
        # .transpose(2,3) swaps kernel height and width, i.e. transposes every kernel
        x = nn.functional.conv2d(x, self.conv_weight.transpose(2,3), bias=None, stride=1, padding=0, dilation=1, groups=1)
        return x
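
As a quick sanity check (the input size is arbitrary), both convolutions read from, and send gradients to, the single shared parameter:

net = Net()
x = torch.randn(1, 1, 28, 28)
net(x).sum().backward()
print(net.conv_weight.grad.shape)          # torch.Size([1, 1, 5, 5]) -- one gradient for the shared kernel
print(sum(1 for _ in net.parameters()))    # 1 -- only one weight tensor to optimize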

Bravo, this problem seems to be solved.
But when I run your code above on my GPU it shows:

RuntimeError: cuDNN requires contiguous weight tensor

It works on the CPU, and it also works on the GPU if I replace the convs with linear layers (keeping the transpose unchanged). Is this error caused by transposing the convolution weight, i.e. by some cuDNN requirement?

I don’t know exactly why cuDNN gives this error. As far as I know .transpose works by changing the strides of the view, not by moving the data, so the underlying storage is still contiguous afterwards; but the transposed tensor no longer counts as contiguous in PyTorch’s sense, because its strides no longer match a standard row-major layout. I suppose cuDNN is just picky and wants a weight tensor with conventional strides.

Replacing the 2nd call to .conv2d with the following might work… though I think it will end up making a copy of the weight tensor on every forward pass.

x = nn.functional.conv2d(x, self.conv_weight.transpose(2,3).contiguous(), ...)
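
For completeness, here is the full module with that change applied; the class name NetContig is just my label, everything else mirrors the Net above:

import torch
import torch.nn as nn

class NetContig(nn.Module):
    def __init__(self):
        super(NetContig, self).__init__()
        self.conv_weight = nn.Parameter(torch.randn(1,1,5,5))

    def forward(self, x):
        x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
        # .contiguous() copies the transposed view into a tensor with standard strides, which cuDNN accepts
        x = nn.functional.conv2d(x, self.conv_weight.transpose(2,3).contiguous(), bias=None, stride=1, padding=0, dilation=1, groups=1)
        return x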

Thank you very much, it’s been solved.