Extract batch size and channels from torch tensor (multi modal network)

I have a multi-modal network with the following forward function:

def forward(self, x):    # x.shape torch.Size([1, 2, 200, 100])
    
    x1= x.data[0][0,:,:] # torch.Size([200, 100])
    x2= x.data[0][1,:,:]  # torch.Size([200, 100])
    out_x1 = self.conv(x1)
    out_x2 = self.conv(x2)

The input x has shape torch.Size([1, 2, 200, 100]), i.e. [batch=1, ch=2, height=200, width=100], and I want x1 to be the first channel and x2 to be the second channel. If I use:

x1= x.data[0][0,:,:]

then I only get the height and width, torch.Size([200, 100]), which does not work with self.conv because it expects torch.Size([1, 1, 200, 100]). How should I extract x1 and x2 from x so that they keep the four dimensions I need?

Thank you!

You can use the view function to reshape it.

x1 = (x.data[0][0, :, :]).view(1, 1, 200, 100)
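The leading 1 passed to view is what ties this to a batch of one. A minimal variant, assuming the same [batch, 2, 200, 100] layout and a hypothetical batch of 4, keeps the batch dimension in the slice and passes -1 so that view infers it:

```python
import torch

x = torch.rand(4, 2, 200, 100)  # hypothetical input batch of 4

# Keep every sample (":" in dim 0), take channel 0, then let view infer dim 0 via -1.
x1 = x[:, 0, :, :].view(-1, 1, 200, 100)
print(x1.shape)  # torch.Size([4, 1, 200, 100])
```

Only one dimension may be -1, so height and width are still written out here.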

Thank you. But then I would not be able to change the batch size. This is a nested network, so it would be good if I could extract the batch size as well.

My mistake, I misunderstood what you were asking. I tried the code below and got these results:

import torch

simulated_batch_size = 64
X = torch.rand((simulated_batch_size, 2, 200, 100), dtype=torch.float64)

x1 = X[:, 0, :, :]
print(x1.shape)  # [64, 200, 100]
x2 = X[:, 1, :, :]
print(x2.shape)  # [64, 200, 100]

#### Reshaped for Conv ####
x1 = x1.view(x1.shape[0], 1, 200, 100)
print(x1.shape) # [64, 1, 200, 100]
x2 = x2.view(x2.shape[0], 1, 200, 100)
print(x2.shape) # [64, 1, 200, 100]
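Put together inside a forward method, the same idea might look like the sketch below. The class name TwoBranchNet and the Conv2d settings are hypothetical placeholders, since the original self.conv is not shown; batch size, height and width are all read from x.shape, so any batch size works.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    """Hypothetical two-modality model: one shared conv applied to each input channel."""

    def __init__(self):
        super().__init__()
        # Assumed layer: 1 input channel, 8 output channels, padding keeps H and W unchanged.
        self.conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)

    def forward(self, x):  # x: [batch, 2, height, width]
        batch, _, height, width = x.shape
        # Index x directly (not x.data) so autograd tracks the slices.
        x1 = x[:, 0, :, :].view(batch, 1, height, width)
        x2 = x[:, 1, :, :].view(batch, 1, height, width)
        out_x1 = self.conv(x1)
        out_x2 = self.conv(x2)
        return out_x1, out_x2


model = TwoBranchNet()
for batch_size in (1, 64):
    out_x1, out_x2 = model(torch.rand(batch_size, 2, 200, 100))
    print(out_x1.shape)  # batch_size x 8 x 200 x 100
```

Returning both branch outputs as a tuple is also an assumption; a real multi-modal model would typically fuse them afterwards.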

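So I believe in your case it would be

x1 = x[:, 0, :, :]

which gives a tensor of shape [1, 200, 100]. Reshape it with x1.view(x1.shape[0], 1, 200, 100) as above to get the [1, 1, 200, 100] that self.conv expects. Indexing x directly rather than x.data also keeps the slice tracked by autograd, so gradients can flow back through it.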
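A different approach, not from the thread, is to slice in a way that never drops the channel dimension, so no reshape is needed and height and width are not hardcoded. A sketch assuming the same [batch, 2, height, width] input and a hypothetical batch of 5; all three forms produce [batch, 1, height, width]:

```python
import torch

x = torch.rand(5, 2, 200, 100)  # hypothetical input batch of 5

# Option 1: a range slice (0:1) keeps the channel dimension.
x1 = x[:, 0:1, :, :]
x2 = x[:, 1:2, :, :]

# Option 2: index the channel, then re-insert a size-1 dimension at position 1.
x1_alt = x[:, 0].unsqueeze(1)

# Option 3: split along dim 1 into chunks of size 1 (returns a tuple of two tensors).
x1_split, x2_split = torch.split(x, 1, dim=1)

print(x1.shape, x1_alt.shape, x1_split.shape)
# torch.Size([5, 1, 200, 100]) torch.Size([5, 1, 200, 100]) torch.Size([5, 1, 200, 100])
```

Option 3 scales naturally if more modalities are stacked along the channel dimension.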

Thank you! That works great