I am trying to train a network on mixed data: for example, a single input might consist of an image and a coordinate. I want to preprocess each part separately and then feed the combined features into a multilayer linear network. Here is what I am starting with:
class TwoInputNet(torch.nn.Module):
    """Encode an image and a 2-value location separately, then merge them.

    The image goes through three strided convolutions and is flattened; the
    location goes through a small MLP. Both feature vectors are concatenated
    along the feature axis and passed through a final MLP.

    Args:
        img_channels: Number of channels in the input image.
        h_img: Input image height in pixels.
        w_img: Input image width in pixels.
        conv_channels_img: Output channels of the first convolution (the
            second and third convolutions output twice as many).
        loc_hidden: Width of the first hidden layer of the location MLP.
        loc_out: Number of features produced by the location MLP.
        feature_dim: Width of the network output.
        hidden_dim: Width of the first hidden layer of the merged MLP.

    Raises:
        ValueError: If the image is too small to survive the three
            convolutions.
    """

    # Shared by every convolution below; the size computation in
    # _conv_stack_size_out relies on the same values, so keep them in sync.
    _KERNEL_SIZE = 5
    _STRIDE = 2
    _NUM_CONVS = 3

    def __init__(self, img_channels, h_img, w_img, conv_channels_img,
                 loc_hidden, loc_out, feature_dim, hidden_dim):
        super().__init__()
        # Spatial size left after the convolution stack, per axis.
        convw_img = self._conv_stack_size_out(w_img)
        convh_img = self._conv_stack_size_out(h_img)
        if convw_img < 1 or convh_img < 1:
            raise ValueError(
                f"image of size {h_img}x{w_img} is too small for "
                f"{self._NUM_CONVS} convolutions with kernel "
                f"{self._KERNEL_SIZE} and stride {self._STRIDE}"
            )
        # Flattened image features (last conv has 2 * conv_channels_img
        # channels) plus the location features.
        state_dim = convw_img * convh_img * 2 * conv_channels_img + loc_out

        self.img_preprocessor = torch.nn.Sequential(
            torch.nn.Conv2d(img_channels, conv_channels_img,
                            kernel_size=self._KERNEL_SIZE,
                            stride=self._STRIDE),
            torch.nn.BatchNorm2d(conv_channels_img),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(conv_channels_img, 2 * conv_channels_img,
                            kernel_size=self._KERNEL_SIZE,
                            stride=self._STRIDE),
            torch.nn.BatchNorm2d(2 * conv_channels_img),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(2 * conv_channels_img, 2 * conv_channels_img,
                            kernel_size=self._KERNEL_SIZE,
                            stride=self._STRIDE),
            torch.nn.BatchNorm2d(2 * conv_channels_img),
            torch.nn.LeakyReLU(),
            torch.nn.Flatten(),
        )
        self.loc_preprocessor = torch.nn.Sequential(
            torch.nn.Linear(2, loc_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(loc_hidden, 2 * loc_hidden),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(2 * loc_hidden, loc_out),
            torch.nn.LeakyReLU(),
        )
        # NOTE(review): the original output layer referenced an undefined
        # `action_dim` while `feature_dim` was never used; `feature_dim` is
        # taken as the output width here -- confirm this is the intent.
        self.collected = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim * 2),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_dim * 2, feature_dim),
        )

    @classmethod
    def _conv_stack_size_out(cls, size):
        """Return the length of one spatial axis after all convolutions.

        Uses the unpadded, undilated Conv2d output-size formula once per
        convolution layer.
        """
        for _ in range(cls._NUM_CONVS):
            size = (size - (cls._KERNEL_SIZE - 1) - 1) // cls._STRIDE + 1
        return size

    @staticmethod
    def collate(samples):
        """Stack per-sample ``(image, location)`` pairs into batch tensors.

        Args:
            samples: Iterable of ``(image, location)`` pairs. Each image is
                assumed to be array-like of shape ``(C, H, W)`` and each
                location array-like with 2 values -- TODO confirm against
                the data pipeline.

        Returns:
            A tuple ``(img_batch, loc_batch)`` of float32 tensors with a new
            leading batch axis, suitable as the argument of ``forward``.
        """
        images, locations = zip(*samples)
        img_batch = torch.stack(
            [torch.as_tensor(img, dtype=torch.float32) for img in images]
        )
        loc_batch = torch.stack(
            [torch.as_tensor(loc, dtype=torch.float32).reshape(-1)
             for loc in locations]
        )
        return img_batch, loc_batch

    def forward(self, ipt):
        """Run the network on an ``(img_obs, loc_obs)`` pair of tensors.

        Both tensors normally carry a leading batch axis. An image without
        one (3-D) or a location without one (1-D) is treated as a batch of
        one, which keeps single-sample calls working.

        Returns:
            Tensor whose first axis is the batch and whose second axis has
            ``feature_dim`` entries.
        """
        img_obs, loc_obs = ipt
        # Add the batch axis to the *inputs* rather than to the location
        # features, so both branches always emit (batch, features) and the
        # concatenation below works for any batch size.
        if img_obs.dim() == 3:
            img_obs = img_obs.unsqueeze(0)
        if loc_obs.dim() == 1:
            loc_obs = loc_obs.unsqueeze(0)
        img_features = self.img_preprocessor(img_obs)
        loc_features = self.loc_preprocessor(loc_obs)
        merged = torch.cat([img_features, loc_features], dim=1)
        return self.collected(merged)
I added the `unsqueeze` to the `forward` function because the `img_preprocessor` outputs a different shape from the `loc_preprocessor`. However, this causes problems when batching, since a batch is a list of lists and not a list of tensors, so `torch.cat` doesn't work. How can I format my inputs so that the batching doesn't break?