Hi, I have two tensors, tensor1 and tensor2, each with shape (16, 768, 64, 64). I want to stack them into an output tensor of shape (16, 2, 768, 64, 64), BUT I also want the 768-channel dimension to be an alternating zip of tensor1 and tensor2.
That is, tensor1[:, 0, :, :] and tensor2[:, 0, :, :] should be stacked together into a (16, 2, 64, 64) tensor, likewise for every other channel, and then that list of stacked tensors should be stacked again.
I don't think there is a way to do this with torch.cat, is there? So I wrote this little function that should do it for me:

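```python
import torch

def zip_and_stack_tensors(tensor1, tensor2):
    num_of_channels = len(tensor1[0, :, 0, 0])  # number of channels (768 here)
    imagelist = []
    for i in range(0, num_of_channels):
        # Pair channel i of both inputs: (16, 2, 64, 64)
        tensorimage = torch.stack((tensor1[:, i, :, :], tensor2[:, i, :, :]), dim=1)
        imagelist.append(tensorimage)
    # Stack the per-channel pairs along a new dim 2: (16, 2, 768, 64, 64)
    output = torch.stack(imagelist, dim=2)
    return output
```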

However, the backward pass now takes about 8x longer when I call this function, while the forward pass time does not change considerably. I guess that's because torch.stack currently gets called 768 + 1 times?
Is there a way to prevent this, or to write my function more efficiently?
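One rough way to compare the backward-pass cost is to time loss.backward() for the loop version and for a single torch.stack on the same inputs. This is a minimal sketch under assumptions: the shapes are scaled down (a hypothetical (4, 64, 16, 16) instead of (16, 768, 64, 64)) so it runs quickly on CPU, the squared-sum loss is arbitrary, and any timings it prints are hardware-dependent, not measurements from this thread. It also checks that both versions yield identical gradients, so any slowdown comes from the extra autograd nodes the loop records (two indexing ops and one stack per channel, plus the final stack), not from a different result. The backward of each indexing op generally builds a zero-filled gradient the size of the whole input, which may explain why the backward pass suffers much more than the forward pass.

```python
import time

import torch


def zip_and_stack_tensors(tensor1, tensor2):
    # Loop version from the question (channel count via tensor1.size(1)).
    imagelist = []
    for i in range(tensor1.size(1)):
        imagelist.append(torch.stack((tensor1[:, i, :, :], tensor2[:, i, :, :]), dim=1))
    return torch.stack(imagelist, dim=2)


def time_backward(fn, t1, t2):
    # Runs forward once, times backward once, returns (seconds, grad of t1, grad of t2).
    t1 = t1.detach().clone().requires_grad_(True)
    t2 = t2.detach().clone().requires_grad_(True)
    out = fn(t1, t2)
    loss = (out * out).sum()  # arbitrary scalar loss with non-trivial gradients
    start = time.perf_counter()
    loss.backward()
    elapsed = time.perf_counter() - start
    return elapsed, t1.grad, t2.grad


# Hypothetical, scaled-down shapes so the sketch runs quickly on CPU.
a = torch.rand(4, 64, 16, 16)
b = torch.rand(4, 64, 16, 16)

loop_time, loop_g1, loop_g2 = time_backward(zip_and_stack_tensors, a, b)
stack_time, stack_g1, stack_g2 = time_backward(lambda x, y: torch.stack([x, y], dim=1), a, b)

print(f"loop backward:  {loop_time:.4f} s")
print(f"stack backward: {stack_time:.4f} s")
# Same gradients either way; only the size of the autograd graph differs.
print(torch.allclose(loop_g1, stack_g1), torch.allclose(loop_g2, stack_g2))
```

Scaling the channel count back up toward 768 is a simple way to see whether the gap grows with the number of loop iterations.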

I'm not sure I understand what you're trying to do here. You can get the same result as your zip_and_stack_tensors() with a single torch.stack(). Am I missing something?

```python
import torch

def zip_and_stack_tensors(tensor1, tensor2):
    num_of_channels = len(tensor1[0, :, 0, 0])
    imagelist = []
    for i in range(0, num_of_channels):
        tensorimage = torch.stack((tensor1[:, i, :, :], tensor2[:, i, :, :]), dim=1)
        imagelist.append(tensorimage)
    output = torch.stack(imagelist, dim=2)
    return output

t1 = torch.rand(16, 768, 64, 64)
t2 = torch.rand(16, 768, 64, 64)
out = zip_and_stack_tensors(t1, t2)
# A single stack along a new dim 1 gives the same (16, 2, 768, 64, 64) result
out_opt = torch.stack([t1, t2], 1)
print(out.size() == out_opt.size())  # True
print((out - out_opt).abs().max())   # tensor(0.)
```
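On the torch.cat question from the original post: torch.stack along dim=1 can be written as torch.cat once each input is given a size-1 axis at dim 1 with unsqueeze, and indexing the result at position k along dim 1 returns the k-th input. The quick check below uses scaled-down shapes (an assumption; the identities do not depend on the sizes).

```python
import torch

# Scaled-down shapes; the identities below do not depend on the sizes.
t1 = torch.rand(2, 5, 4, 4)
t2 = torch.rand(2, 5, 4, 4)

stacked = torch.stack([t1, t2], dim=1)  # shape (2, 2, 5, 4, 4)

# torch.cat expresses the same thing once each input gets a new size-1 axis at dim 1.
via_cat = torch.cat([t1.unsqueeze(1), t2.unsqueeze(1)], dim=1)

print(stacked.shape)                  # torch.Size([2, 2, 5, 4, 4])
print(torch.equal(stacked, via_cat))  # True

# Channel i of the stacked result pairs t1[:, i] with t2[:, i], exactly like the loop.
i = 3
print(torch.equal(stacked[:, 0, i], t1[:, i]))  # True
print(torch.equal(stacked[:, 1, i], t2[:, i]))  # True
```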

Yes, the shape is the same in both cases, but the order along the depth dimension is different, because a plain stack does not zip the tensors alternately.
Let's say I have
t1 of shape (16, 3, 64, 64), where:
t1[:, 0, :, :] = all 1s
t1[:, 1, :, :] = all 3s
t1[:, 2, :, :] = all 5s
and
t2 of shape (16, 3, 64, 64), where:
t2[:, 0, :, :] = all 2s
t2[:, 1, :, :] = all 4s
t2[:, 2, :, :] = all 6s

My function creates a tensor t3 whose depth is ordered 1, 2, 3, 4, 5, 6, whereas using only stack creates a tensor t3 whose depth is ordered 1, 3, 5, 2, 4, 6, doesn't it?

In the code sample, I print the difference with print((out - out_opt).abs().max()), and it gives 0 for me. So the single stack does exactly the same thing as the function you shared.
Is the function you shared doing what you want?
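The 1s-to-6s example from the question can be checked directly. This sketch assumes smaller 4x4 spatial dims for brevity and assumes that "depth order" means flattening the (tensor, channel) axes into one axis with reshape; depth_order is a hypothetical helper that reads one value per flattened slice. It shows that the question's loop and torch.stack([t1, t2], dim=1) are identical, and that both flatten to 1, 3, 5, 2, 4, 6. The interleaved 1, 2, 3, 4, 5, 6 order only appears if the new axis is placed after the channel axis (dim=2) before flattening.

```python
import torch

# The question's example, with hypothetical smaller spatial dims (4x4 instead of 64x64).
batch, channels, h, w = 16, 3, 4, 4
t1 = torch.stack([torch.full((batch, h, w), v) for v in (1.0, 3.0, 5.0)], dim=1)  # 1s, 3s, 5s
t2 = torch.stack([torch.full((batch, h, w), v) for v in (2.0, 4.0, 6.0)], dim=1)  # 2s, 4s, 6s


def zip_and_stack_tensors(tensor1, tensor2):
    # Equivalent to the question's loop: one stack per channel, then stack along dim 2.
    imagelist = [torch.stack((tensor1[:, i], tensor2[:, i]), dim=1) for i in range(tensor1.size(1))]
    return torch.stack(imagelist, dim=2)


loop_out = zip_and_stack_tensors(t1, t2)   # (16, 2, 3, 4, 4)
stack_dim1 = torch.stack([t1, t2], dim=1)  # (16, 2, 3, 4, 4)
print(torch.equal(loop_out, stack_dim1))   # True: same layout, same values


def depth_order(x):
    # Hypothetical helper: merge dims 1 and 2 into one "depth" axis, read one value per slice.
    flat = x.reshape(batch, -1, h, w)
    return flat[0, :, 0, 0].tolist()


print(depth_order(stack_dim1))  # [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]

# Putting the new axis after the channel axis gives the interleaved order instead.
stack_dim2 = torch.stack([t1, t2], dim=2)  # (16, 3, 2, 4, 4)
print(depth_order(stack_dim2))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

In other words, dim only decides where the new size-2 axis goes; the order you see after flattening follows from that position.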

You're absolutely right. I just had a massive misconception about how the dim parameter works in torch.stack.
But yes, this makes total sense now.
Thanks for the help!