Problem: I have data of shape (B, a, b, c, d), i.e. a batch size B followed by some dimensions a, b, c, d.
Now I want to view this data along a different dimension each time (i.e. with that dimension as the first non-batch dimension) and apply some transforms in that view.
For instance, viewed along dimension c it will be of shape (B, c, -1); viewed along dimension d it will be of shape (B, d, -1).
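For concreteness, here is what one such view looks like on a toy tensor (a minimal sketch; the sizes 8, 3, 5, 7, 11 are just placeholders for B, a, b, c, d):

import torch

x = torch.randn(8, 3, 5, 7, 11)            # (B, a, b, c, d)

# view along c: move dim 3 next to the batch dim, flatten everything else
xc = x.transpose(3, 1).reshape(8, 7, -1)   # (B, c, -1), here (8, 7, b*a*d)

# view along d: move dim 4 next to the batch dim, flatten everything else
xd = x.transpose(4, 1).reshape(8, 11, -1)  # (B, d, -1), here (8, 11, b*c*a)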
My current approach results in too much reshaping, which makes the gradient calculation very slow and the forward pass somewhat slow. Any hints to improve this would be helpful. Sorry for using this forum as a sounding board. The current approach looks roughly like this:
# inside forward(); assumes `import random`; K and self.layer (one module per non-batch dimension) come from the class
for _ in range(K):
    size = x.shape                           # x is (B, a, b, c, d), so len(size) == 5
    d = random.randint(1, 4)                 # pick one of the four non-batch dimensions
    # dimension order after transpose(d, 1); starts from [0, 1, 2, 3, 4]
    final = [0, 1, 2, 3, 4]
    final[1], final[d] = final[d], final[1]
    x = x.transpose(d, 1)                    # move dim d next to the batch dim (a view)
    x = x.reshape(size[0], size[d], -1)      # flatten the rest (copies, since x is now non-contiguous)
    x = self.layer[d - 1](x)                 # per-dimension layer, assumed to preserve the flattened size
    x = x.reshape([size[i] for i in final])  # restore the permuted 5-D shape
    x = x.transpose(d, 1)                    # undo the transpose, back to (B, a, b, c, d)
Update: The problem can be restated as: apply a linear layer to a flattened view of the last k dimensions of a non-contiguous tensor, without copying it.
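The "without copying" part is the crux: the transpose itself is free (it only permutes strides), but the reshape that follows generally cannot stay a view once the dimensions being flattened are no longer adjacent in memory, so it materialises a full copy, and every loop iteration above pushes the whole activation tensor through memory this way. A minimal check of that, using the same toy sizes as before:

import torch

x = torch.randn(8, 3, 5, 7, 11)         # (B, a, b, c, d)

t = x.transpose(3, 1)                   # just a stride permutation, no data movement
print(t.data_ptr() == x.data_ptr())     # True: same underlying storage

f = t.reshape(8, 7, -1)                 # the dims being merged are no longer contiguous
print(f.data_ptr() == x.data_ptr())     # False: reshape had to allocate and copy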