I am trying to find out whether an idea I had is possible in PyTorch.
My goal is to have an image classifier that selects the most relevant reference image for a reference-based super-resolution model.
(green: forward pass, blue: backward pass)
So first, an image is sent to a classifier. The classifier's output should select one image from a fixed stack of images. The selected image is then sent to the super-resolution model (SRM).
Is it possible to map a classification output to discrete 3D classes without losing the ability to differentiate the whole thing?

    import torch
    import torch.nn as nn
    import torchvision.models as models

    class RefSelector(nn.Module):
        def __init__(self):
            super(RefSelector, self).__init__()
            # Pretrained ResNet-18 backbone with frozen weights
            self.RSel = models.resnet18(pretrained=True)
            for param in self.RSel.parameters():
                param.requires_grad_(False)
            # New trainable head: one logit per reference image
            self.RSel.fc = nn.Linear(in_features=512, out_features=10, bias=True)

        def forward(self, LR):
            return self.RSel(LR)

    ref_selector = RefSelector()
    inputImage = torch.rand(1, 3, 128, 128)  # ResNet expects a batch dimension
    refImages = torch.rand(10, 3, 128, 128)  # fixed stack of 10 reference images
    output = ref_selector(inputImage)        # logits, shape (1, 10)

My first idea was to just use any image classifier, e.g. ResNet, but of course if I apply
ReferenceID = torch.argmax(output, dim=1)
to that and then select the right reference image with it, no gradient can be calculated in the backward pass: argmax returns an integer index, so the selected reference image does not depend on the classifier's output in any differentiable way.
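To make the problem concrete, here is a minimal sketch. It uses a tiny linear layer as a hypothetical stand-in for the ResNet classifier and 8x8 images, both assumptions chosen only to keep the example small. The first half shows that indexing with the argmax result gives a tensor with no gradient path back to the classifier. The second half is only one possible direction, not a verified solution: it writes the selection as a product with a one-hot vector and uses `torch.nn.functional.gumbel_softmax` with `hard=True`, which returns a one-hot vector in the forward pass while passing gradients of the soft sample in the backward pass. Whether this trains well together with the SRM is untested here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the classifier: one linear layer on flattened pixels.
classifier = nn.Linear(3 * 8 * 8, 10)
ref_images = torch.rand(10, 3, 8, 8)       # fixed stack of reference images
lr_image = torch.rand(1, 3, 8, 8)          # batch with one input image

logits = classifier(lr_image.flatten(1))   # shape (1, 10)

# Hard selection: argmax yields an integer index, which carries no gradient.
ref_id = torch.argmax(logits, dim=1)
hard_ref = ref_images[ref_id]              # shape (1, 3, 8, 8)
hard_has_grad = hard_ref.requires_grad     # no path back to the classifier

# Selection as a product with a one-hot vector (straight-through estimator).
# Forward: exactly one reference image is picked. Backward: soft gradients.
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)           # shape (1, 10)
soft_ref = torch.einsum("bk,kchw->bchw", one_hot, ref_images)    # (1, 3, 8, 8)

# Stand-in for the SRM loss: any scalar computed from the selected image.
soft_ref.sum().backward()
soft_has_grad = classifier.weight.grad is not None
```

In this sketch `hard_has_grad` is `False` while `soft_has_grad` is `True`, which is the difference the question is about. Note that the Gumbel noise makes the selection stochastic, so the picked image can differ from the plain argmax choice.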
Does anyone have an idea whether this is possible, and if so, how?