Apply vectorized random permutations to conv2d weight tensors

I want to have a method that takes the weight of a conv2d layer, typically of shape (C_out, C_in, H, W), where C_out is the number of output channels, C_in is the number of input channels, and H and W are the spatial dimensions of the filters (kernels).

As an example, consider a weight:

w = torch.randint(0, 5, (4, 1, 3, 3))
tensor([[[[3, 3, 1],
          [3, 4, 2],
          [1, 2, 3]]],


        [[[3, 0, 1],
          [3, 1, 3],
          [0, 2, 2]]],


        [[[4, 0, 1],
          [2, 4, 3],
          [2, 4, 0]]],


        [[[0, 0, 4],
          [2, 2, 1],
          [3, 4, 0]]]])

For this weight, I would like to apply a row-wise permutation to every filter. This means taking each of the four filters, which can be viewed as a 3x3 matrix, and randomly permuting its rows.

I can implement this, but only with a naive for loop:

import torch

def test_conv():
    w = torch.randint(0, 5, (4, 1, 3, 3))

    w_permuted = w.clone()
    for idx, kernel in enumerate(w_permuted):
        # generate a random permutation of the row indices
        key = torch.randperm(kernel.shape[1])
        # permute the rows of the current 3x3 kernel with the generated indices
        w_permuted[idx] = kernel[:, key]

    print(w)
    print(w_permuted)

Q: Is there a more elegant, vectorized way to generate w_permuted using PyTorch?

Thanks in advance

If I understand your use case correctly, you want to randomly permute the rows of all kernels without mixing rows between kernels.
I think your code is fine, as it clearly shows your intent, and I don't know whether this is currently a bottleneck in your code.
However, you could try torch.gather as seen here and compare it against your approach:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
w = torch.arange(4 * 2 * 3 * 3).view(4, 2, 3, 3).to(device)
w_ = w.clone()
# one random row permutation per kernel (building the indices still loops in Python)
key = torch.stack([torch.randperm(w_.shape[2]) for _ in range(w_.shape[0] * w_.shape[1])]).to(device)
torch.gather(w_, 2, key.view(4, 2, 3, 1).expand_as(w_))
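
A quick sanity check: after the gather, sorting along the row dimension should match the sorted original, since rows are only reordered within each kernel:

permuted = torch.gather(w_, 2, key.view(4, 2, 3, 1).expand_as(w_))
# each column is merely reordered within its kernel, so sorting along dim 2 must agree
assert torch.equal(permuted.sort(dim=2).values, w_.sort(dim=2).values)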

Hi Robert!

Yes. The largest inefficiency in your approach is the Python loop. It can be
avoided by using argsort() to compute a whole batch of permutations without any Python loop.

Please see this post:

As noted in the linked post, argsort() has O(n log n) cost, with a factor of log(n)
that is theoretically unnecessary when generating permutations. But in your example,
n is 3, so the asymptotic inefficiency is entirely irrelevant (as it will be for an n
coming from any reasonable convolution kernel).
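
For concreteness, here is a minimal sketch of the argsort() trick applied to your example (the exact indexing is one possible variant):

import torch

w = torch.randint(0, 5, (4, 1, 3, 3))

# draw i.i.d. uniform noise for every row slot and argsort along the row
# dimension -- each (kernel, channel) slice yields an independent permutation
key = torch.rand(w.shape[0], w.shape[1], w.shape[2]).argsort(dim=2)

# broadcast the permutation indices across the columns and gather the rows
w_permuted = torch.gather(w, 2, key.unsqueeze(-1).expand_as(w))

Because argsort() operates on the whole noise tensor at once, no Python loop is needed, no matter how many kernels the weight contains.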

Best.

K. Frank


Thanks a lot for providing this


Interesting proposal! I was not aware of argsort().