Defining "forward" function properly for mixed data

I am trying to train a network on mixed data. For example, a single input might consist of an image and a coordinate. I want to preprocess each part separately and then feed the combined features into a multilayer linear network. Here is what I am starting with:

class TwoInputNet(torch.nn.Module):
    def __init__(self, img_channels,h_img,w_img,conv_channels_img,loc_hidden,loc_out, feature_dim, hidden_dim):
        super(TwoInputNet, self).__init__()
        # Spatial size after the three stride-2 convolutions (conv2d_size_out is a helper not shown here)
        convw_img = conv2d_size_out(conv2d_size_out(conv2d_size_out(w_img)))
        convh_img = conv2d_size_out(conv2d_size_out(conv2d_size_out(h_img)))
        state_dim = (convw_img * convh_img * 2 * conv_channels_img)+loc_out
        self.img_preprocessor = torch.nn.Sequential(
                        torch.nn.Conv2d(img_channels,conv_channels_img, kernel_size = 5,stride = 2),
                        torch.nn.BatchNorm2d(conv_channels_img),
                        torch.nn.LeakyReLU(),
                        torch.nn.Conv2d(conv_channels_img,2*conv_channels_img,kernel_size = 5,stride = 2),
                        torch.nn.BatchNorm2d(2*conv_channels_img),
                        torch.nn.LeakyReLU(),
                        torch.nn.Conv2d(2*conv_channels_img,2*conv_channels_img,kernel_size = 5,stride = 2),
                        torch.nn.BatchNorm2d(2*conv_channels_img),
                        torch.nn.LeakyReLU(),
                        torch.nn.Flatten()
                )
        
                
        self.loc_preprocessor = torch.nn.Sequential(
                        torch.nn.Linear(2,loc_hidden),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(loc_hidden,2*loc_hidden),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(2*loc_hidden,loc_out),
                        torch.nn.LeakyReLU()
                )
        
        self.collected = torch.nn.Sequential(
                        torch.nn.Linear(state_dim, hidden_dim),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(hidden_dim, hidden_dim*2),
                        torch.nn.LeakyReLU(),
                        torch.nn.Linear(hidden_dim*2, action_dim)  # action_dim is not a constructor argument; assumed to be defined elsewhere
                )
    def forward(self,ipt):
        img_obs,loc_obs = ipt
        x1 = self.img_preprocessor(img_obs)
        x2 = self.loc_preprocessor(loc_obs)
        x = torch.cat([x1,torch.unsqueeze(x2,0)],dim = 1)
        x = self.collected(x)
        return x

I added the unsqueeze to the forward function because the output of img_preprocessor has a different shape from the output of loc_preprocessor. However, this causes problems with batches: a batch is a list of lists rather than a list of tensors, so torch.cat doesn’t work. How can I format my inputs so that batching doesn’t break?

If your input batch contains a list of lists, I assume each tensor in the nested list has a different shape?
If that’s the case, you could pad the tensors so that they have the same shape, and stack them to a single tensor.
Currently PyTorch modules do not accept a batch with differently shaped tensors out of the box.
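
As a rough illustration of the pad-and-stack idea, here is a minimal sketch. The helper names `pad_to` and `collate_pairs` and the sample shapes are made up for this example and are not from the original post; it assumes each sample is an (image, location) pair with a (C, H, W) image and a length-2 location tensor.

```python
import torch
import torch.nn.functional as F


def pad_to(t, height, width):
    """Zero-pad a (C, H, W) tensor on the right and bottom to (C, height, width)."""
    pad_w = width - t.shape[-1]
    pad_h = height - t.shape[-2]
    # F.pad takes pairs starting from the last dim: (left, right, top, bottom)
    return F.pad(t, (0, pad_w, 0, pad_h))


def collate_pairs(samples):
    """Turn a list of (image, location) pairs into two batched tensors."""
    imgs = [img for img, _ in samples]
    locs = [loc for _, loc in samples]
    max_h = max(img.shape[-2] for img in imgs)
    max_w = max(img.shape[-1] for img in imgs)
    img_batch = torch.stack([pad_to(img, max_h, max_w) for img in imgs])  # (N, C, H, W)
    loc_batch = torch.stack(locs)                                          # (N, 2)
    return img_batch, loc_batch


# Hypothetical samples with slightly different image sizes
samples = [
    (torch.randn(3, 60, 64), torch.tensor([0.1, 0.2])),
    (torch.randn(3, 64, 64), torch.tensor([0.3, 0.4])),
    (torch.randn(3, 64, 58), torch.tensor([0.5, 0.6])),
]
img_batch, loc_batch = collate_pairs(samples)

# Stand-ins for the two preprocessor outputs: both are (N, features),
# so they can be concatenated along dim=1 without any unsqueeze.
x1 = torch.randn(img_batch.shape[0], 128)
x2 = torch.randn(loc_batch.shape[0], 16)
x = torch.cat([x1, x2], dim=1)
```

Once both inputs carry a leading batch dimension, the forward function can receive `(img_batch, loc_batch)` and concatenate the two feature tensors directly. A function like `collate_pairs` could also be passed as the `collate_fn` argument of `torch.utils.data.DataLoader`. If all images already share one size, the padding step does nothing and plain `torch.stack` is enough.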
