Avoiding reshape operations when changing the channel dimension

Problem: I have data of shape (B, a, b, c, d), i.e. the batch size followed by some dimensions (a, b, …).
Now I want to view this data along a different dimension (i.e. with that dimension placed first, right after the batch dimension) and apply some transforms in this view.
For instance, when viewed along dimension c it will have shape (B, c, -1);
viewed along d it will have shape (B, d, -1).
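
A minimal sketch of this "view along a dimension" and its inverse, assuming a 5-D tensor laid out as (B, a, b, c, d). The helper names `view_along` and `unview_along` are hypothetical, and the sketch uses `movedim` rather than `transpose` so the remaining dimensions keep their original order.

```python
import torch


def view_along(x: torch.Tensor, d: int) -> torch.Tensor:
    """Hypothetical helper: return x as (B, x.shape[d], -1), with dim d moved to position 1."""
    # movedim returns a view; reshape then copies unless the moved strides still allow a view (e.g. d=1)
    return x.movedim(d, 1).reshape(x.shape[0], x.shape[d], -1)


def unview_along(y: torch.Tensor, shape: torch.Size, d: int) -> torch.Tensor:
    """Hypothetical inverse of view_along, given the original `shape` and `d`."""
    # Shape of x.movedim(d, 1): batch, then dim d, then the other dims in their original order
    moved = [shape[0], shape[d]] + [s for i, s in enumerate(shape) if i not in (0, d)]
    return y.reshape(moved).movedim(1, d)


x = torch.randn(2, 3, 4, 5, 6)       # (B, a, b, c, d)
y = view_along(x, 3)                 # viewed along c
print(y.shape)                       # torch.Size([2, 5, 72])
print(torch.equal(unview_along(y, x.shape, 3), x))  # True
```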

My current approach results in too much reshaping, which makes the gradient calculation very slow and the forward pass somewhat slow. Any hints to improve this would be helpful. Sorry for using this forum as a sounding board.

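            for _ in range(K):
                size = x.shape                  # shape is an attribute, not a method
                d = random.randint(1, 4)        # inclusive: any non-batch dim of a 5-D x
                # initial = [0, 1, 2, 3, 4]
                final = [0, d, 2, 3, 4]         # dim order after transpose(d, 1)
                final[d] = 1
                x = x.transpose(d, 1)
                x = x.reshape(size[0], size[d], -1)  # copies whenever the transposed x cannot be viewed
                x = self.layer[d - 1](x)        # must keep the last dim size for the reshape below
                x = x.reshape(
                    size[final[0]], size[final[1]], size[final[2]],
                    size[final[3]], size[final[4]])
                x = x.transpose(d, 1)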
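
For reference, a self-contained sketch of the loop above as a module, assuming a 5-D input (B, a, b, c, d) and that each `layer[d - 1]` is a square `nn.Linear` over the flattened remaining dims (neither is stated in the post). `RandomAxisMixer` is a hypothetical name; reading the transposed shape back from `x.shape` replaces the `final` bookkeeping.

```python
import math
import random

import torch
import torch.nn as nn


class RandomAxisMixer(nn.Module):
    """Hypothetical module reproducing the loop above for a 5-D input (B, a, b, c, d)."""

    def __init__(self, dims, k: int):
        super().__init__()
        self.k = k
        total = math.prod(dims)  # a * b * c * d
        # layer[d - 1] acts on the flattened remainder when viewing along dim d,
        # and keeps that size so the result can be reshaped back
        self.layer = nn.ModuleList(nn.Linear(total // n, total // n) for n in dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.k):
            size = x.shape
            d = random.randint(1, 4)               # inclusive, so any non-batch dim
            x = x.transpose(d, 1)                  # a view, no copy
            moved = x.shape                        # size with dims 1 and d swapped
            x = x.reshape(size[0], size[d], -1)    # copies when x cannot be viewed
            x = self.layer[d - 1](x)
            x = x.reshape(moved).transpose(d, 1)   # undo the flatten, then the transpose
        return x


model = RandomAxisMixer(dims=(2, 3, 4, 5), k=3)
out = model(torch.randn(8, 2, 3, 4, 5))
print(out.shape)  # torch.Size([8, 2, 3, 4, 5])
```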

Update: the problem can be restated as: apply a linear layer to a flattened view of the last k dimensions of a non-contiguous tensor without copying it.
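
One way to sidestep flattening the input is to reshape the Linear weight instead: a Linear over the flattened dims equals a `torch.tensordot` of the unflattened input with the weight viewed as (out, *dims). This is a sketch, assuming the flatten order of `x.movedim(keep, 1)`; `linear_over_dims` is a hypothetical helper. Note that `tensordot` internally permutes and reshapes its operands before the matmul, so this removes the reshape from user code but may still copy internally.

```python
import torch
import torch.nn as nn


def linear_over_dims(x: torch.Tensor, linear: nn.Linear, keep: int) -> torch.Tensor:
    """Hypothetical helper: apply `linear` to the non-batch dims of x other than `keep`.

    Equivalent to linear(x.movedim(keep, 1).reshape(B, x.shape[keep], -1)).
    Returns shape (B, x.shape[keep], linear.out_features).
    """
    other = [i for i in range(1, x.dim()) if i != keep]
    # View the contiguous (out, in) weight as (out, *sizes of the other dims)
    w = linear.weight.view(linear.out_features, *(x.shape[i] for i in other))
    # Contract x's other dims with w's trailing dims; result dims: (B, keep, out)
    y = torch.tensordot(x, w, dims=(other, list(range(1, w.dim()))))
    return y if linear.bias is None else y + linear.bias


x = torch.randn(2, 3, 4, 5, 6)                 # (B, a, b, c, d)
lin = nn.Linear(3 * 4 * 6, 7)                  # acts on a*b*d when viewing along c
out = linear_over_dims(x, lin, keep=3)
ref = lin(x.movedim(3, 1).reshape(2, 5, -1))   # the flatten-based version
print(out.shape)                               # torch.Size([2, 5, 7])
print(torch.allclose(out, ref, atol=1e-5))     # True
```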

@avish you should be able to use torch.transpose

b=torch.transpose(a, 3, 1)
print([a.size(), b.size()])

Is this not what you wanted?

Yes, this is what I am currently doing. Now say that after this we want to take the last 3 dimensions, apply flatten(2), and pass the result to nn.Linear(2*3*3, op_dim). The flatten in between makes a copy of the tensor, because the transposed tensor cannot be viewed (its strides are not compatible). This copy is what currently makes the gradient calculation so slow.
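
A small check of that copy, assuming an input shape whose last three dims are (2, 3, 3) to match nn.Linear(2*3*3, op_dim). Comparing `data_ptr()` is a rough test that works here because all the views start at storage offset 0.

```python
import torch

a = torch.randn(2, 3, 2, 3, 3)   # assumed shape; last three dims give 2*3*3 = 18 features
b = a.transpose(3, 1)            # a view: shares a's storage, but is non-contiguous

flat_a = a.flatten(2)            # a is contiguous, so flatten returns a view
flat_b = b.flatten(2)            # b's strides do not allow a view, so flatten copies

print(flat_a.data_ptr() == a.data_ptr())  # True
print(flat_b.data_ptr() == b.data_ptr())  # False
```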

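@avish understood. That would be tricky without creating a copy of the tensor. The reason is that the transpose function returns a view that shares the underlying storage of the original tensor and only swaps the sizes and strides of the two dimensions, so the result is non-contiguous.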
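
To see the shared storage concretely, a sketch with an input of shape (2, 3, 2, 1, 3), which is an assumption chosen to be consistent with the sizes printed in the example below:

```python
import torch

a = torch.randn(2, 3, 2, 1, 3)        # assumed shape; contiguous
b = torch.transpose(a, 3, 1)          # shape (2, 1, 2, 3, 3)

print(a.stride())                     # (18, 6, 3, 3, 1)
print(b.stride())                     # (18, 3, 3, 6, 1): strides of dims 1 and 3 swapped
print(b.data_ptr() == a.data_ptr())   # True: same storage, nothing copied
print(b.is_contiguous())              # False
```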

To illustrate the issue, check the code below. After the transpose, the view call gives an error, but the same view call does not give an error on the “original” tensor.

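import torch

a = torch.randn(2, 3, 2, 1, 3)  # shape assumed; consistent with the sizes printed below
print(a.view(2, 1, 2, -1).size())
print(a.view(2, 1, 2, -1).size())

b = torch.transpose(a, 3, 1)
print(b.size())
print(b.view(2, 1, 2, -1))  # raises RuntimeError: b is non-contiguous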
torch.Size([2, 1, 2, 9])
torch.Size([2, 1, 2, 9])
torch.Size([2, 1, 2, 3, 3])
RuntimeError                              Traceback (most recent call last)
<ipython-input-39-b6aaccdd97cd> in <module>()
      8 b=torch.transpose(a, 3, 1)
      9 print(b.size())
---> 10 print(b.view(2, 1, 2, -1))

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
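
The error comes from the stride rule that view() enforces: adjacent dims can only be merged if each dim's stride equals the next dim's stride times the next dim's size. Below is a simplified check of that rule (the helper `can_flatten_tail` is hypothetical and ignores edge cases such as empty tensors), plus the two usual workarounds, both of which copy:

```python
import torch


def can_flatten_tail(t: torch.Tensor, start: int) -> bool:
    """Hypothetical, simplified view() rule: dims start.. can be merged into one if each
    dim's stride equals the next dim's stride times the next dim's size (size-1 dims ignored)."""
    dims = [(n, s) for n, s in zip(t.shape[start:], t.stride()[start:]) if n != 1]
    return all(s == next_s * next_n for (_, s), (next_n, next_s) in zip(dims, dims[1:]))


a = torch.randn(2, 3, 2, 1, 3)
b = torch.transpose(a, 3, 1)                # shape (2, 1, 2, 3, 3), strides (18, 3, 3, 6, 1)

print(can_flatten_tail(b, 3))               # False: b.view(2, 1, 2, -1) raises
print(can_flatten_tail(b.contiguous(), 3))  # True

c = b.reshape(2, 1, 2, -1)                  # works, but copies
d = b.contiguous().view(2, 1, 2, -1)        # same result, with the copy made explicit
print(c.shape, torch.equal(c, d))           # torch.Size([2, 1, 2, 9]) True
```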

Will try to find a way to achieve this. There are several discuss threads, and also a GitHub thread, on the same topic.

Thanks. If possible, could you share those threads? I was not able to find any.
For inspiration, I was trying to implement something like MLP-Mixer but for higher-dimensional data.