Torch.view is painful to use. Is there a function for splitting or merging dimensions?

Most of the time when I am defining a custom operation, I am not concerned with the rest of the dimensions. For example, when defining multi-head attention for convs, I don't care about batch, width, or height, or even whether the conv is 1d, 2d, or 3d; I only care about the feature dimension (dim=1). This is where view and permute become really painful to use. Let's limit this thread to torch.view.

Now I just want to split the feature dimension in two without converting the tensor to a list and having to concat it back again. Why is there no function to do so? Or at least I couldn't find one. Suppose I have a tensor of shape (B, C, H, W) and I want to split the channel dimension into N chunks.

There should be a function like torch.split(x, dim=1, (C // N, N)) that results in a (B, C//N, N, H, W) shape. Using torch.view makes this painful, especially when writing a generalized operation for all convolutions: I have to read the shape and perform multiple operations just to achieve this result, and on top of that it breaks PyTorch tracing once shapes get involved. There should also be a torch.merge function that merges dimensions (the reverse of the split): torch.merge(x, (1, 2)) should revert (B, C//N, N, H, W) back to (B, C, H, W). It would save so much time.
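For what it's worth, the proposed split/merge semantics are close to what `Tensor.unflatten` and `torch.flatten(start_dim, end_dim)` already provide in reasonably recent PyTorch versions. A quick sketch (the shapes here are just example values):

```python
import torch

B, C, H, W = 16, 64, 32, 32
N = 4  # number of groups to split the channel dim into
x = torch.rand(B, C, H, W)

# Split dim 1 (channels) into (C // N, N), leaving B, H, W untouched.
split = x.unflatten(1, (C // N, N))                    # (B, C // N, N, H, W)

# Merge dims 1 and 2 back together (the proposed torch.merge).
merged = torch.flatten(split, start_dim=1, end_dim=2)  # (B, C, H, W)
```

Both only need the dimensions you care about, so they work the same for 1d, 2d, or 3d convs, and for a contiguous input no data should be copied.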

Same issue with permute: if I want to swap two dimensions, I have to specify all of them, even though I am not concerned with the W, H, or L at the end. It becomes really hard to write generalized operations with the hard-coded approach view and permute currently take.
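For the two-dimension swap specifically, `Tensor.transpose(dim0, dim1)` only needs the two dims involved, and `torch.movedim` relocates a single dim while keeping the rest in order, so neither requires spelling out the full permutation:

```python
import torch

x = torch.rand(16, 64, 32, 32)   # (B, C, H, W)

# Swap dims 1 and 2 without listing every dimension.
y = x.transpose(1, 2)            # (B, H, C, W)

# Move dim 1 to the end, leaving all other dims in order.
z = torch.movedim(x, 1, -1)      # (B, H, W, C)
```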

For your specific example, there is a chunk method:

x = torch.rand(16, 3, 32, 32)

x1, x2, x3 = x.chunk(3, dim=1)

torch.cat can reverse the op.

y = torch.cat([x1, x2, x3], dim=1)

You can also simply specify the ranges, like in numpy:

x1, x2, x3 = x[:, 0:1, ...], x[:, 1:2, ...], x[:, 2:3, ...]

Additionally, torch.einsum is a good kit to use for miscellaneous tensor operations.
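As a small illustration of what einsum can do here (the letters and shapes below are just example choices), a per-head attention-style contraction needs no manual permute or view juggling at all:

```python
import torch

B, H, L, D = 2, 4, 10, 8  # batch, heads, sequence length, head dim
q = torch.rand(B, H, L, D)
k = torch.rand(B, H, L, D)

# Per-head attention scores: contract over the head dim D,
# keeping batch and head dims in place.
scores = torch.einsum("bhld,bhmd->bhlm", q, k)  # (B, H, L, L)
```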

That sounds very inefficient. The computation graph is going to be a mess, and this operation alone would be a huge bottleneck in training if we have to split the tensor into many chunks, e.g. 64 or 128.

It's also going to rely on a Python list and the GIL, so it might not work with PyTorch tracing. This could be far less efficient than the view method.

PyTorch is an open-source project developed by its community. If you have an idea that you believe would improve the ecosystem, you can submit code for inclusion:

https://pytorch.org/docs/stable/community/contribution_guide.html
