Inside my model's forward() method I split the input image into three channels and pass each one through a separate module. After that, I need to combine the per-channel outputs back into a single tensor.
So before merging I have three tensors, each with shape [batch_size, 1, 128, 128].
After merging I want one tensor with shape [batch_size, 3, 128, 128].
How do I do it?
I think torch.cat will get the job done.
Here is how I would do it:
import torch
x = torch.randn(32, 1, 128, 128)  # dummy tensor for the demo; you don't need this part
new_tensor = torch.cat((x, x, x), dim=1)  # concatenate along dim 1 (the channel dim); this is the part you need
new_tensor.shape should then be torch.Size([32, 3, 128, 128]), which is the result you want.
Here x just stands in for your tensors. With your three per-channel outputs a, b and c, you would do it like this:
new = torch.cat((a, b, c), dim=1)
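
To show where this fits inside forward(), here is a minimal sketch of the whole split, process, merge flow. I don't know what your per-channel modules look like, so the class name PerChannelNet and the 3x3 conv branches are placeholders I made up; swap in your own modules.

```python
import torch
import torch.nn as nn


class PerChannelNet(nn.Module):
    """Hypothetical model: one small branch per colour channel."""

    def __init__(self):
        super().__init__()
        # Placeholder branches (assumption); each maps [B, 1, H, W] -> [B, 1, H, W].
        self.branches = nn.ModuleList(
            [nn.Conv2d(1, 1, kernel_size=3, padding=1) for _ in range(3)]
        )

    def forward(self, img):
        # Split size 1 along dim 1 yields three [B, 1, H, W] tensors.
        channels = torch.split(img, 1, dim=1)
        # Run each channel through its own branch.
        outs = [branch(c) for branch, c in zip(self.branches, channels)]
        # Concatenate along the channel dimension -> [B, 3, H, W].
        return torch.cat(outs, dim=1)


model = PerChannelNet()
out = model(torch.randn(32, 3, 128, 128))
print(out.shape)  # torch.Size([32, 3, 128, 128])
```

Note that torch.stack is not what you want here: it adds a new dimension, so stacking three [batch_size, 1, 128, 128] tensors along dim 1 would give [batch_size, 3, 1, 128, 128] instead.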