Merging 3 image batch tensors

Inside my model’s forward() method, I split the input image into its three channels and pass each one through a separate module. After that, I need to combine the three per-channel batches back into a single tensor.

So before merging, I have three tensors, each with shape [batch_size, 1, 128, 128].

After merging, I need one tensor with shape [batch_size, 3, 128, 128].

How do I do it?

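I think torch.cat is what you need here. It concatenates tensors along an existing dimension, so joining your three tensors on dim 1 (the channel dimension) gives the shape you want.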
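As a quick sanity check of how torch.cat relates to the split described in the question, the sketch below assumes the split is done with torch.split along dim 1 (the question does not say which function is used). It shows that concatenating the pieces back along the same dimension, in the same order, reproduces the original tensor exactly. The batch size of 4 is arbitrary.

```python
import torch

# Dummy batch: 4 RGB images of 128x128 (shape chosen for illustration only)
images = torch.randn(4, 3, 128, 128)

# Split along dim 1 (channels) into three [4, 1, 128, 128] tensors
red, green, blue = torch.split(images, 1, dim=1)
print(red.shape)  # torch.Size([4, 1, 128, 128])

# Concatenate along the same dimension to get back [4, 3, 128, 128]
merged = torch.cat((red, green, blue), dim=1)
print(merged.shape)  # torch.Size([4, 3, 128, 128])

# The round trip is exact when the pieces are concatenated in their original order
print(torch.equal(merged, images))  # True
```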

So here is how I think you should do it.

import torch

x = torch.randn(32, 1, 128, 128)  # dummy tensor for the demo; you don't need this line
new_tensor = torch.cat((x, x, x), 1)  # concatenate along dim 1 (channels); this is the part you need

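This should give you torch.Size([32, 3, 128, 128]), which is the result you want.
In that example x just stands in for each of your tensors, so with your three tensors a, b and c (kept in the original channel order) you would write:
new = torch.cat((a, b, c), 1)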
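Putting it together inside forward(), a minimal sketch might look like the following. The class name PerChannelNet and the Conv2d branches are hypothetical stand-ins for your own per-channel modules. The one assumption that matters is that each branch returns a [batch_size, 1, 128, 128] tensor, since torch.cat requires all inputs to match in every dimension except the one being concatenated.

```python
import torch
import torch.nn as nn


class PerChannelNet(nn.Module):
    """Hypothetical model: one small sub-module per colour channel."""

    def __init__(self):
        super().__init__()
        # Assumption: each branch maps [B, 1, H, W] -> [B, 1, H, W];
        # a 3x3 conv with padding=1 keeps the spatial size unchanged.
        self.branches = nn.ModuleList(
            [nn.Conv2d(1, 1, kernel_size=3, padding=1) for _ in range(3)]
        )

    def forward(self, x):
        # Split [B, 3, H, W] into three [B, 1, H, W] tensors
        channels = torch.split(x, 1, dim=1)
        # Run each channel through its own branch
        outputs = [branch(c) for branch, c in zip(self.branches, channels)]
        # Merge back along the channel dimension -> [B, 3, H, W]
        return torch.cat(outputs, dim=1)


model = PerChannelNet()
out = model(torch.randn(8, 3, 128, 128))
print(out.shape)  # torch.Size([8, 3, 128, 128])
```

Keeping the branches in an nn.ModuleList registers their parameters with the model, so model.parameters() picks them up for the optimizer. If your branches instead returned [batch_size, 128, 128] tensors without the singleton channel dimension, torch.stack(outputs, dim=1) would produce the same [batch_size, 3, 128, 128] result.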
