I need to make this code computationally efficient. Here is the scenario:

```
x # torch.Size([2, 5, 256]) (Batch, input_1st_dim, input_2nd_dim)
```

Now I want to combine the `input_1st_dim` entries index-wise: the 1st with each of the following four, then the 2nd with the next three, then the 3rd with the next two, then the 4th with the last one.

Finally, I want to get `x` as `[2, 15, not sure about this dimension]`.

I can do it with nested loops and some extra lines of code. Here is my code:

```
def my_fun(self, x):
    # x: (5, 256). Average every unordered pair of rows (10 pairs),
    # then append the pair means to the original rows -> (15, 256).
    remaining = x.shape[0]
    counter = 0
    new_x = torch.zeros((10, x.shape[1]), dtype=torch.float32, device=x.device)
    for i in range(x.shape[0] - 1):
        remaining -= 1
        for j in range(remaining):
            # pair row i with row i+j+1 (i+j would pair row i with itself
            # on the first iteration and miss the last pair)
            new_x[counter, :] = (x[i, :] + x[i + j + 1, :]) / 2
            counter += 1
    return torch.cat((x, new_x), dim=0)
```

```
ref = torch.zeros((x.shape[0], 15, x.shape[2]), dtype=torch.float32, device=x.device)
for i in range(x.shape[0]):
    ref[i, :, :] = self.my_fun(x[i, :, :])
```

But, is there any computationally efficient way to code it?
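For reference, here is a vectorized sketch of what I believe is the same computation, using `torch.triu_indices` to enumerate all row pairs at once instead of nested loops (the function name `pairwise_means` is mine):

```
import torch

def pairwise_means(x):
    # x: (B, N, D). triu_indices with offset=1 yields the N*(N-1)/2
    # unordered index pairs (0,1), (0,2), ..., (N-2,N-1) in the same
    # order as the nested loops above.
    n = x.shape[1]
    i, j = torch.triu_indices(n, n, offset=1)
    means = (x[:, i, :] + x[:, j, :]) / 2   # (B, N*(N-1)/2, D)
    return torch.cat((x, means), dim=1)     # (B, N + N*(N-1)/2, D)

x = torch.randn(2, 5, 256)
print(pairwise_means(x).shape)  # torch.Size([2, 15, 256])
```

This avoids both the Python-level loops and the intermediate `ref` buffer, and it runs on whatever device `x` already lives on.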