Intra-layer tied weights in a convolution layer

I’m using the functional interface to conv2d and am looking for an efficient way to define a weight tensor W that satisfies the constraint W[i, j, a, b] = W[j, i, ks-1-a, ks-1-b] for all i, j, a, b, where ks is the kernel size.
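The constraint can be written as a single tensor identity: swapping the two channel axes and then reversing both spatial axes (a 180-degree rotation of every kernel) must give back W. A minimal check, assuming W has shape (nf, nf, ks, ks); the helper name `satisfies_tie` is hypothetical:

```python
import torch


def satisfies_tie(W: torch.Tensor) -> bool:
    """True if W[i, j, a, b] == W[j, i, ks-1-a, ks-1-b] for every index (up to float tolerance)."""
    nf_out, nf_in, kh, kw = W.shape
    if nf_out != nf_in or kh != kw:
        return False  # the tie only makes sense for square channel and kernel dims
    # Swap out/in channels, then reverse both spatial axes (a 180-degree rotation)
    partner = torch.flip(W.permute(1, 0, 2, 3), dims=(2, 3))
    return torch.allclose(W, partner)
```

This is handy for testing any construction of W, including the ones discussed below.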

How should I create and register W as a Parameter in my module’s constructor? The ties between entries of W make this seem difficult.

I’ve tried defining the following in my constructor:

```python
W0 = nn.Parameter(torch.zeros(nf, nf, ks, ks), requires_grad=True)
```

and then creating W in my forward function via:

```python
W0_flipped = flip(flip(W0, 3), 2)
W = (W0 + W0_flipped.permute(1, 0, 2, 3)) / 2.0
```

where “flip” is defined as in https://github.com/pytorch/pytorch/issues/229. The resulting W satisfies the constraint, but W0 holds roughly twice as many parameters as necessary, since each entry of W is the average of two entries of W0. I then use W in calls like:

```python
f = nn.functional.conv2d(inputs, W, bias=b, stride=1, padding=pad_f)
```
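For reference, the approach above can be packaged as a module roughly like this. It is a sketch under stated assumptions: the class name `SymmetrizedConv2d` is hypothetical, it uses the built-in `torch.flip` in place of the `flip` helper from the linked issue, it assumes "same" padding (`ks // 2`, odd `ks`), and it uses a small random initialization instead of zeros.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SymmetrizedConv2d(nn.Module):
    """Hypothetical module: symmetrizes an unconstrained W0 on every forward pass."""

    def __init__(self, nf: int, ks: int):
        super().__init__()
        # Unconstrained tensor; assumption: small random init instead of zeros
        self.W0 = nn.Parameter(0.1 * torch.randn(nf, nf, ks, ks))
        self.b = nn.Parameter(torch.zeros(nf))
        self.pad = ks // 2  # assumption: "same" padding for odd ks

    def weight(self) -> torch.Tensor:
        # torch.flip reverses both spatial axes (replaces flip(flip(W0, 3), 2))
        W0_flipped = torch.flip(self.W0, dims=(2, 3))
        # Average W0[i, j, a, b] with W0[j, i, ks-1-a, ks-1-b]
        return (self.W0 + W0_flipped.permute(1, 0, 2, 3)) / 2.0

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.conv2d(inputs, self.weight(), bias=self.b, stride=1, padding=self.pad)
```

The constraint holds exactly, but `W0` still stores all `nf * nf * ks * ks` values.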

How can I create this without redundant parameters?
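To quantify the redundancy: the tie pairs entry (i, j, a, b) with (j, i, ks-1-a, ks-1-b). Each channel pair i < j contributes one free ks x ks kernel, while each diagonal kernel W[i, i] must equal its own 180-degree rotation and therefore has ceil(ks²/2) free entries (the centre entry pairs with itself when ks is odd). A small helper, with the hypothetical name `param_counts`, comparing this with the size of W0:

```python
def param_counts(nf: int, ks: int):
    """Return (entries in an unconstrained W0, free entries under the tie)."""
    full = nf * nf * ks * ks                  # W0 as defined in the question
    off_diag = nf * (nf - 1) // 2 * ks * ks   # one free kernel per channel pair i < j
    diag = nf * ((ks * ks + 1) // 2)          # W[i, i] equals its own 180-degree rotation
    return full, off_diag + diag


print(param_counts(3, 3))
```

For nf = 3 and ks = 3 this gives 81 versus 42, so W0 is roughly twice as large as needed.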

Hi,

Here is a code sample that generates the convolution weights from the minimal set of free parameters:

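```python
import torch

nf = 3  # number of channels (in_channels == out_channels)
ks = 3  # kernel size

# One free ks x ks kernel for each unordered pair of distinct channels i < j
diff_chan_params = torch.rand(nf * (nf - 1) // 2, ks, ks)
# W[i, i] must equal its own 180-degree rotation, so only the first
# ceil(ks*ks / 2) entries of each flattened kernel are free
same_chan_params = torch.rand(nf, (ks * ks + 1) // 2)
print("learnable parameters")
print(diff_chan_params)
print(same_chan_params)


def rot180(k):
    # k[a, b] -> k[ks-1-a, ks-1-b]
    return torch.flip(k, dims=(0, 1))


def get_params(chan_in, chan_out):
    if chan_in == chan_out:
        half = same_chan_params[chan_in]
        # A 180-degree rotation maps flat index t to ks*ks-1-t,
        # so the second half of the kernel is the first half reversed
        flat = torch.cat([half, torch.flip(half[: ks * ks // 2], dims=(0,))])
        return flat.view(ks, ks)
    small_idx = min(chan_in, chan_out)
    large_idx = max(chan_in, chan_out)
    # Position of the pair (small_idx, large_idx) in a packed triangle
    lin_idx = small_idx + large_idx * (large_idx - 1) // 2
    param = diff_chan_params[lin_idx]
    if chan_in < chan_out:
        # W[i, j] = rot180(W[j, i])
        param = rot180(param)
    return param


conv_params = []
for out_c in range(nf):
    tmp = []
    for in_c in range(nf):
        tmp.append(get_params(in_c, out_c))
    conv_params.append(torch.stack(tmp, 0))
full_params = torch.stack(conv_params, 0)  # shape (out, in, ks, ks), as conv2d expects
print("conv parameters")
print(full_params)
```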
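To register only the free values as a Parameter in a module constructor, one option (a sketch, not the only way) is to precompute an integer index map that sends every W[i, j, a, b] to a slot in a flat parameter vector, with tied entries sharing a slot. The forward pass then builds W with a single gather instead of per-channel-pair Python calls, and autograd sums the gradients of tied entries into their shared slot. The names `build_tie_index`, `TiedConv2d`, `free` and `tie_index` are hypothetical, and "same" padding plus a small random initialization are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def build_tie_index(nf: int, ks: int) -> torch.Tensor:
    """Map every W[i, j, a, b] to a slot of a flat parameter vector; tied entries share a slot."""
    idx = torch.full((nf, nf, ks, ks), -1, dtype=torch.long)
    n = 0
    for i in range(nf):
        for j in range(nf):
            for a in range(ks):
                for b in range(ks):
                    if idx[i, j, a, b].item() >= 0:
                        continue  # already assigned via its tied partner
                    idx[i, j, a, b] = n
                    idx[j, i, ks - 1 - a, ks - 1 - b] = n
                    n += 1
    return idx


class TiedConv2d(nn.Module):
    """Hypothetical module that stores only the free weights and gathers the full kernel."""

    def __init__(self, nf: int, ks: int):
        super().__init__()
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained
        self.register_buffer("tie_index", build_tie_index(nf, ks))
        n_free = int(self.tie_index.max()) + 1
        self.free = nn.Parameter(0.1 * torch.randn(n_free))  # assumption: small random init
        self.bias = nn.Parameter(torch.zeros(nf))
        self.pad = ks // 2  # assumption: "same" padding for odd ks

    def weight(self) -> torch.Tensor:
        # One gather builds the (nf, nf, ks, ks) kernel; gradients of tied entries accumulate
        return self.free[self.tie_index]

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.conv2d(inputs, self.weight(), bias=self.bias, stride=1, padding=self.pad)
```

For nf = 3 and ks = 3 this registers 42 weight values (plus the bias) instead of 81. The Python loop runs only once, in the constructor, so it does not slow down the forward pass.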

Thanks so much albanD. This was super helpful. Really appreciate it.