Most of the time when I am defining a custom operation (say, multi-head attention for convs), I am only concerned with the feature dimension (dim=1). I am not concerned with batch, width, or height, or with whether the conv operation is 1d, 2d, or 3d. This is where view and permute become really painful to use. Let's limit this thread to torch.view.
Now suppose I just want to split the feature dimension in two, without converting the shape to a list and having to concat it back together. Why is there no function to do so? Or at least I couldn't find one. Say I have a tensor of shape (B, C, H, W) and I want to split the channel dimension into N chunks.
There should be a function like torch.split(x, (C // N, N), dim=1) that results in a (B, C//N, N, H, W) shape. Using torch.view makes this painful, especially when you are writing a generalized operation for all convolutions: I have to take the shape apart and do multiple operations just to achieve this one result, and on top of that, involving shapes ruins PyTorch tracing. There should also be a torch.merge function that merges dimensions (the reverse of the split): torch.merge(x, (1, 2)) should revert (B, C//N, N, H, W) back to (B, C, H, W). It would save so much time.
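For reference, here is a minimal sketch of the kind of view-based helpers described above. The names split_dim and merge_dims are hypothetical (they are not part of the torch API); the bodies just wrap the shape arithmetic that currently has to be written by hand. Note also that recent PyTorch releases provide Tensor.unflatten(dim, sizes) and Tensor.flatten(start_dim, end_dim), which cover roughly these two operations.

```python
import torch

def split_dim(x, dim, sizes):
    # Hypothetical helper: reshape one dimension into several.
    # Wraps the manual shape surgery that view currently requires.
    shape = list(x.shape)
    new_shape = shape[:dim] + list(sizes) + shape[dim + 1:]
    return x.view(new_shape)

def merge_dims(x, start, end):
    # Hypothetical helper: collapse dims start..end (inclusive) into one.
    shape = list(x.shape)
    merged = 1
    for s in shape[start:end + 1]:
        merged *= s
    new_shape = shape[:start] + [merged] + shape[end + 1:]
    return x.view(new_shape)

B, C, H, W = 2, 8, 4, 4
N = 2
x = torch.randn(B, C, H, W)
y = split_dim(x, 1, (C // N, N))   # (B, C//N, N, H, W)
z = merge_dims(y, 1, 2)            # back to (B, C, H, W)
```

The list conversion inside these helpers is exactly the shape-dependent code that breaks tracing, which is why having them as built-in ops (with symbolic shape support) would matter.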
Same issue with permute: if I want to swap two dimensions, I have to specify all of them, even though I am not concerned with the W, H, or L at the end. It becomes really hard to write generalized operations with the hard-coded approach we currently have for view and permute.
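For the specific case of swapping two dimensions, torch.transpose (and torch.movedim, for moving a single dimension) already take only the affected indices, so the full permutation does not have to be spelled out; permute is only unavoidable for a general reordering. A quick sketch:

```python
import torch

x = torch.randn(2, 8, 4, 4)   # (B, C, H, W)

# Swap only dims 1 and 2; trailing dims are untouched.
y = x.transpose(1, 2)         # (B, H, C, W)

# Move the channel dim to the end, regardless of how many dims follow it.
z = torch.movedim(x, 1, -1)   # (B, H, W, C)
```

Because movedim addresses dimensions relatively (including with negative indices), the same call works for 1d, 2d, and 3d convs, which gets at the generalized-operation case above.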