I am trying to solve a problem where I have a set of images divided into N classes. I want to create a network that takes 2 images from this set and returns 2 output classes: Class 1 if both images are from the same class, and Class 0 otherwise. This gives rise to a few problems:
- DataLoader: If I have M distinct images, there are M*(M-1)/2 (i.e. M choose 2) possible input pairs for my network. Is there a way to iterate through all of them? I tried creating 2 datasets pointing to the same ImageFolder path, feeding them into 2 separate dataloaders, and then iterating with a nested for loop:
for phase in ['train', 'val']:
    for data1 in dataloader1:
        for data2 in dataloader2:
            # validate model on the pair (data1, data2)
However, when my phase changes from training to validation, this seems to cause a CUDA out-of-memory error. How do I solve this? Is it possible to create a DataLoader that serves all possible combinations?
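One way to serve every combination from a single DataLoader is a custom Dataset that enumerates all index pairs up front. A minimal sketch (the class name `PairDataset` and the assumption that the base dataset returns `(image, class_label)` tuples, as ImageFolder does, are mine):

```python
import itertools
import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    """Serves every unordered pair (i, j), i < j, from a base dataset.

    `base` is assumed to be any map-style dataset (e.g. an ImageFolder)
    whose items are (image, class_label) tuples.
    """
    def __init__(self, base):
        self.base = base
        # All M*(M-1)/2 index pairs, enumerated once up front.
        self.pairs = list(itertools.combinations(range(len(base)), 2))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        i, j = self.pairs[idx]
        img1, c1 = self.base[i]
        img2, c2 = self.base[j]
        # Label 1 if the two images share a class, 0 otherwise.
        label = int(c1 == c2)
        return img1, img2, label
```

Wrapping this in one DataLoader replaces the nested loop entirely, so only one batch of pairs is resident at a time instead of two full loaders' worth of state.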
- Class: Currently, I compute a new label by comparing the classes of the two samples:
RealClass = int(Class1 == Class2)
Is this the most efficient way to handle this? Can I use a custom dataloader to generate these labels while serving the M*(M-1)/2 image pairs?
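The per-sample comparison is fine; when the labels arrive batched, the same idea vectorizes in one line. A sketch, assuming `class1` and `class2` are 1-D integer tensors of per-sample class labels (the names mirror the snippet above):

```python
import torch

# Batched version of RealClass = int(Class1 == Class2).
# class1 and class2 hold one integer class label per pair in the batch
# (the example values are illustrative).
class1 = torch.tensor([0, 1, 2, 2])
class2 = torch.tensor([0, 2, 2, 1])

# Elementwise comparison gives a bool tensor; .long() casts it to the
# 0/1 integer targets expected by nn.CrossEntropyLoss.
real_class = (class1 == class2).long()
# real_class is tensor([1, 0, 1, 0])
```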
- Forward pass: The model I am making needs to “boil down” the image into K features. To do so, I have created 2 sub-models as follows:
Model1: Input - Image. Output - Tensor of length K.
Model2: Input - Tensor of length 2K. Output - 2 output classes (as mentioned above)
For each forward pass, I run each image through Model1, then concatenate the two K-length tensors (dim=1) into a 2K-length tensor, which I pass to Model2. I then backpropagate the error by comparing the output class to the expected class. Is it possible to do away with Model2, and instead compute the Euclidean distance between the two K-length vectors and classify into 2 classes with a threshold (as a parameter)? How do I set this up to ensure smooth backpropagation?
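A hard distance threshold is not differentiable, but a learnable threshold pushed through a sigmoid is, so gradients still reach Model1. One possible sketch (the module name `DistanceHead` and the learnable `threshold`/`scale` parameters are my assumptions; `feat1`/`feat2` stand in for Model1's K-length outputs):

```python
import torch
import torch.nn as nn

class DistanceHead(nn.Module):
    """A possible replacement for Model2: classifies a pair from the
    Euclidean distance between its two embeddings, using a learnable
    threshold instead of a hard cutoff."""
    def __init__(self):
        super().__init__()
        # Learnable threshold and sharpness; the sigmoid keeps the
        # decision differentiable, unlike a hard distance cutoff.
        self.threshold = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, feat1, feat2):
        # Euclidean distance per pair in the batch.
        dist = torch.norm(feat1 - feat2, dim=1)
        # Probability of "same class": high when dist < threshold.
        return torch.sigmoid(self.scale * (self.threshold - dist))

# Usage sketch with random stand-ins for Model1's outputs:
head = DistanceHead()
feat1 = torch.randn(4, 8, requires_grad=True)
feat2 = torch.randn(4, 8)
p_same = head(feat1, feat2)            # shape (4,), values in (0, 1)
target = torch.tensor([1., 0., 1., 0.])
loss = nn.functional.binary_cross_entropy(p_same, target)
loss.backward()                        # gradients reach feat1, i.e. Model1
```

An alternative with the same flavor is the standard contrastive loss, which trains the embedding directly so that same-class pairs are close and different-class pairs are at least a margin apart, with no classification head at all.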
Thank you for your time,