I am using a `DataLoader` and then iterating with

```python
for input_batch, input_labels in data_loader:
```

as is torch-onic.
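For concreteness, here is roughly what my current pipeline looks like (a minimal sketch; the dataset class, shapes, and sizes are just illustrative):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    """Stand-in dataset: one image tensor and one label per sample."""
    def __init__(self):
        self.images = torch.randn(100, 3, 32, 32)   # dummy images
        self.labels = torch.randint(0, 10, (100,))  # dummy class labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

data_loader = DataLoader(ImageDataset(), batch_size=8)
for input_batch, input_labels in data_loader:
    print(input_batch.shape)  # torch.Size([8, 3, 32, 32])
    break
```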
However, my model needs a "split input", i.e. its `forward` function needs to do something like this:
```python
def forward(self, x):
    resnet_input, fc_input = x
    resnet_out = self.resnet(resnet_input)
    # concatenate the ResNet features with the extra inputs along the feature dim
    out = self.fc_layer(torch.cat((resnet_out, fc_input), dim=1))
    return out
```
But `torch.utils.data.Dataset.__getitem__` is (as far as I can tell) expected to return a `(tensor, tensor)` pair, so I don't know how to do this.
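To make it concrete, this is roughly what I would *like* my dataset to return (a sketch; the class name, shapes, and the 4 extra features are made up), but I don't know whether a nested tuple like this is something the `DataLoader` can batch:

```python
import torch
from torch.utils.data import Dataset

class SplitInputDataset(Dataset):
    """Sketch of what I'd like: two input tensors plus a label per sample."""
    def __init__(self):
        self.images = torch.randn(100, 3, 32, 32)  # ResNet input
        self.extras = torch.randn(100, 4)          # extra features for the fc layer
        self.labels = torch.randint(0, 10, (100,))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Is returning a nested tuple like this even allowed?
        return (self.images[idx], self.extras[idx]), self.labels[idx]
```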
Maybe I should be constructing an input tensor with an extra channel that carries the additional information, and then in `forward` doing something like:
```python
def forward(self, x):
    resnet_input = x[:, :, :, :3]        # first 3 channels: the actual image
    fc_input = x[:, :, :, 3][0:4, 0, 0]  # extra values packed into channel 3
    ...
```
That's pretty ugly though (and I'm just assuming backpropagation still works when you slice the input like that…)
This feels like something that must have a standard way of doing it; can you please educate me?