I want a method that takes the weight of a conv2d layer, which typically has the shape (C_out, C_in, H, W), where C_out is the number of output channels, C_in is the number of input channels, and H and W are the height and width of the filters (kernels).
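To make the layout concrete, here is a quick check that a layer with 1 input channel, 4 output channels and 3x3 kernels stores its weight with the (4, 1, 3, 3) shape used below. This is a sketch that assumes a plain torch.nn.Conv2d with the default groups=1; with grouped convolutions the second dimension is C_in / groups instead.

```python
import torch

# Conv2d with 1 input channel, 4 output channels and 3x3 kernels (groups=1 by default)
conv = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)

# The weight is laid out as (C_out, C_in, H, W)
print(conv.weight.shape)  # torch.Size([4, 1, 3, 3])
```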
As an example, consider a weight:
w = torch.randint(0, 5, (4, 1, 3, 3))
tensor([[[[3, 3, 1],
          [3, 4, 2],
          [1, 2, 3]]],

        [[[3, 0, 1],
          [3, 1, 3],
          [0, 2, 2]]],

        [[[4, 0, 1],
          [2, 4, 3],
          [2, 4, 0]]],

        [[[0, 0, 4],
          [2, 2, 1],
          [3, 4, 0]]]])
For this weight, I would like to apply a row-wise (line-wise) permutation to every filter independently. That is, treat each of the four filters as a 3x3 matrix and randomly permute its rows.
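A single-kernel illustration of what permuting the rows means, using the first filter from the example above. The permutation key below is hand-picked for illustration (an assumption standing in for a random torch.randperm draw), so the result is deterministic.

```python
import torch

# First filter from the example, viewed as a plain 3x3 matrix (C_in dimension dropped)
kernel = torch.tensor([[3, 3, 1],
                       [3, 4, 2],
                       [1, 2, 3]])

# Fixed permutation of the row indices; torch.randperm(3) would produce a random one
key = torch.tensor([2, 0, 1])

# Indexing dim 0 with the permutation reorders whole rows; each row stays intact
print(kernel[key])
# tensor([[1, 2, 3],
#         [3, 3, 1],
#         [3, 4, 2]])
```

In the loop version the kernel keeps its leading C_in dimension, shape (1, 3, 3), which is why the rows are indexed there as kernel[:, key] rather than kernel[key].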
I can implement this, but only with a naive for loop:
import torch

def test_conv():
    w = torch.randint(0, 5, (4, 1, 3, 3))
    w_permuted = w.clone()
    for idx, kernel in enumerate(w_permuted):
        # generate a random permutation of the row indices (H)
        key = torch.randperm(kernel.shape[1])
        # permute the rows of the current (C_in, 3, 3) kernel; rows are dim 1 here
        w_permuted[idx] = kernel[:, key]
    print(w)
    print(w_permuted)

test_conv()
Q: Is there a more elegant, vectorized way in PyTorch to generate w_permuted?
Thanks in advance
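One possible vectorized sketch, not necessarily the only or most elegant option: draw one permutation per filter in a single call by argsorting uniform noise, then reorder the rows with torch.gather along dim 2 (H). The helper name permute_rows is hypothetical. Like the loop, it shares one permutation across all C_in channels of a given filter.

```python
import torch

def permute_rows(w: torch.Tensor) -> torch.Tensor:
    """Randomly permute the rows (dim 2, H) of every filter independently.

    w has shape (C_out, C_in, H, W). One permutation is drawn per output
    filter and shared across that filter's C_in input channels.
    """
    c_out, c_in, h, width = w.shape
    # argsort of uniform noise yields an independent random permutation per row of noise
    perm = torch.rand(c_out, h, device=w.device).argsort(dim=1)  # (C_out, H), int64
    # Broadcast the permutation to the full weight shape so gather can consume it
    index = perm[:, None, :, None].expand(c_out, c_in, h, width)
    # out[o, i, r, c] = w[o, i, perm[o, r], c]
    return torch.gather(w, 2, index)

w = torch.randint(0, 5, (4, 1, 3, 3))
w_permuted = permute_rows(w)
print(w)
print(w_permuted)
```

torch.randperm returns a single permutation per call, which is why the loop needs one call per filter; argsorting a (C_out, H) noise matrix produces C_out independent permutations at once. Because the random draws differ, this sketch will not reproduce the loop's exact output for the same seed, only the same kind of result.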