Stacking augmented samples in dataloader

I have a network that takes two augmented images (xi, xj) of the same sample x. I stack the two images into one input in the `__getitem__` function.
Now, in my training function, how can I unstack the images (xi, xj) of each input in each batch before feeding them to the model?
Thank you so much!
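The setup described in the question can be sketched as below. This is a minimal sketch, assuming a hypothetical `TwoViewDataset` that wraps an in-memory tensor and uses a placeholder augmentation (random horizontal flip plus Gaussian noise); a real pipeline would plug in its own transforms. Stacking the two views along a new dim 0 gives each sample the shape (2, C, H, W), and the DataLoader's default collate function turns that into a batch of shape (B, 2, C, H, W).

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TwoViewDataset(Dataset):
    """Hypothetical dataset returning two augmented views of one sample,
    stacked along a new leading dim -> shape (2, C, H, W)."""

    def __init__(self, images):
        self.images = images  # tensor of shape (N, C, H, W)

    def __len__(self):
        return self.images.shape[0]

    def augment(self, x):
        # Placeholder augmentation (assumption): random horizontal flip plus small noise
        if torch.rand(1).item() < 0.5:
            x = torch.flip(x, dims=[-1])
        return x + 0.1 * torch.randn_like(x)

    def __getitem__(self, idx):
        x = self.images[idx]
        xi, xj = self.augment(x), self.augment(x)
        return torch.stack((xi, xj), dim=0)  # (2, C, H, W)


data = torch.randn(10, 1, 64, 64)
loader = DataLoader(TwoViewDataset(data), batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 2, 1, 64, 64])
```

An alternative is to return the pair as a tuple (`return xi, xj`) from `__getitem__`: the default collate function then batches each element separately, so the training loop receives `xi` and `xj` as two (B, C, H, W) tensors without any unstacking.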

You could stack the two images along a known dimension so that each one can be recovered again by simple indexing, as shown below.

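```python
import torch

batch_size = 16
img_h, img_w = 64, 64
img_1 = torch.randn(batch_size, 1, img_h, img_w)
img_2 = torch.randn(batch_size, 1, img_h, img_w)

### Concatenate the two views along the channel dimension -> (B, 2, H, W).
### In __getitem__ there is no batch dim yet, so use dim=0 on (1, H, W) tensors there;
### the DataLoader then adds the batch dim in front.
concat_inp = torch.cat((img_1, img_2), dim=1)

### The respective images can be recovered by slicing the channel dimension
### (slicing with 0:1 instead of indexing with 0 keeps the shape (B, 1, H, W))
inp_1 = concat_inp[:, 0:1]
inp_2 = concat_inp[:, 1:2]
```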
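For the training-loop side of the question, here is a minimal sketch of unstacking a batch before the forward pass. It assumes each sample was built with `torch.stack((xi, xj), dim=0)` in `__getitem__`, so a batch has shape (B, 2, C, H, W); the random batch and the small `nn.Sequential` encoder are hypothetical placeholders. It shows plain indexing, `torch.unbind`, and, for the channel-concatenation layout, `torch.split`.

```python
import torch
from torch import nn

# Hypothetical stand-in for a batch from the DataLoader: each sample was built
# in __getitem__ with torch.stack((xi, xj), dim=0), so the batch is (B, 2, C, H, W).
batch_size, channels, img_h, img_w = 16, 1, 64, 64
batch = torch.randn(batch_size, 2, channels, img_h, img_w)

# Hypothetical encoder; any model taking (B, C, H, W) input works here.
model = nn.Sequential(
    nn.Conv2d(channels, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Option 1: index the view dimension; the channel dim survives because
# the views live in their own extra dimension.
xi, xj = batch[:, 0], batch[:, 1]

# Option 2: unbind returns one tensor per view along dim 1.
xi_u, xj_u = batch.unbind(dim=1)
assert torch.equal(xi, xi_u) and torch.equal(xj, xj_u)

# Option 3: if the views were concatenated along channels instead
# (torch.cat((xi, xj), dim=0) in __getitem__ -> batch of shape (B, 2*C, H, W)),
# split it back into two (B, C, H, W) tensors.
cat_batch = torch.cat((xi, xj), dim=1)
xi_c, xj_c = torch.split(cat_batch, channels, dim=1)
assert torch.equal(xi, xi_c) and torch.equal(xj, xj_c)

# Feed each view to the model separately.
zi, zj = model(xi), model(xj)
print(xi.shape, zi.shape)  # torch.Size([16, 1, 64, 64]) torch.Size([16, 8])
```

Stacking along a new dimension (options 1 and 2) avoids having to know the channel count when splitting, whereas concatenating along channels (option 3) requires slicing by `channels` to separate the views.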