Hi everyone,
I am trying to find out whether an idea I had is possible in PyTorch.
My goal is an image classifier that selects the most relevant reference image for a reference-based super-resolution model.
[Figure: pipeline diagram (green: forward pass, blue: backward pass)]
So first, an image is sent to a classifier. The classifier's output should select one image from a fixed stack of reference images. The selected image is then passed to the super-resolution model (SRM).
Is it possible to map a classification output to a discrete choice among these images (each one a 3D tensor) without losing the ability to differentiate through the whole pipeline?
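A common workaround (a general technique, not something from the setup above) is to relax the hard choice: instead of indexing a single image, pass the SRM a softmax-weighted blend of all reference images. Every step is differentiable, so the SR loss can train the classifier. This is a minimal sketch assuming logits of shape `(B, N)` and a stack of `N` images of shape `(C, H, W)`; `soft_select` and its temperature `tau` are hypothetical names chosen for illustration.

```python
import torch
import torch.nn.functional as F


def soft_select(logits: torch.Tensor, ref_images: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Differentiable stand-in for "pick one reference image" (hypothetical helper).

    logits:     (B, N) classifier scores, one per reference image
    ref_images: (N, C, H, W) fixed stack of reference images
    returns:    (B, C, H, W) softmax-weighted blend of the references
    """
    weights = F.softmax(logits / tau, dim=1)  # (B, N); a smaller tau gives a peakier blend
    return torch.einsum("bn,nchw->bchw", weights, ref_images)


# Toy check: a loss on the blended image produces gradients on the logits.
logits = torch.randn(2, 10, requires_grad=True)
refs = torch.rand(10, 3, 128, 128)
blended = soft_select(logits, refs)
loss = blended.mean()  # placeholder for the real SR loss
loss.backward()
print(blended.shape)            # torch.Size([2, 3, 128, 128])
print(logits.grad is not None)  # True
```

The catch is that during training the SRM sees mixtures of references rather than real images. Annealing `tau` towards a small value, or switching to a hard argmax at inference time, are typical ways to narrow that gap.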
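```python
import torch
import torch.nn as nn
import torchvision.models as models


class RefSelector(nn.Module):
    def __init__(self, num_refs=10):
        super().__init__()
        # Pretrained ResNet-18 backbone (weights enum API, torchvision >= 0.13)
        self.RSel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Freeze the pretrained backbone
        for param in self.RSel.parameters():
            param.requires_grad_(False)
        # New, trainable head: one score per candidate reference image
        self.RSel.fc = nn.Linear(in_features=512, out_features=num_refs, bias=True)

    def forward(self, LR):
        return self.RSel(LR)


refSelector = RefSelector()
inputImage = torch.rand(1, 3, 128, 128)   # ResNet expects a batch dimension: (B, C, H, W)
refImages = torch.rand(10, 3, 128, 128)   # fixed stack of 10 reference images
output = refSelector(inputImage)          # shape (1, 10): one score per reference image
```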
My first idea was to just use an off-the-shelf image classifier, e.g. a ResNet. But if I then apply

```python
ReferenceID = torch.argmax(output, dim=1)
```

and use that index to pick the reference image from the stack, no gradient can be computed in the backward pass: `argmax` returns an integer index that carries no gradient, so the selected reference image is not connected to the classifier's output in the computational graph.
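The gradient problem described above can be reproduced in a few lines, together with one standard fix, the straight-through Gumbel-softmax. PyTorch provides it as `torch.nn.functional.gumbel_softmax`; with `hard=True` the forward pass returns an exact one-hot vector (so the SRM sees exactly one real reference image), while the backward pass uses the gradient of the underlying soft sample. This is a sketch under assumed shapes, with random logits standing in for the classifier output; `st_select` is a hypothetical helper name, and the straight-through gradient is a biased estimate rather than the true gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 10, requires_grad=True)  # stand-in for the classifier output, shape (B, N)
refs = torch.rand(10, 3, 128, 128)                # fixed stack of reference images

# 1) Hard argmax: the index is an integer tensor with no gradient history,
#    so the selected image is disconnected from the classifier.
idx = torch.argmax(logits, dim=1)
picked = refs[idx]
print(idx.requires_grad, picked.requires_grad)    # False False


# 2) Straight-through Gumbel-softmax (hypothetical helper): forward uses a
#    one-hot vector, backward uses the gradient of the soft sample.
def st_select(logits, ref_images, tau=1.0):
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=1)  # (B, N)
    return torch.einsum("bn,nchw->bchw", one_hot, ref_images), one_hot


selected, one_hot = st_select(logits, refs)
k = one_hot.argmax(dim=1)
print(torch.allclose(selected[0], refs[k[0]]))    # True: exactly one reference image
selected.mean().backward()                        # placeholder for the real SR loss
print(logits.grad is not None)                    # True
```

In a real model, `logits` would be the classifier's `output`, and `selected` would be fed to the SRM together with the LR input; `tau` is often annealed during training.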
Does anyone have even a vague idea how this could be done, or whether it is possible at all?
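If the SRM should only ever receive a real reference image and a biased straight-through gradient is a concern, another general option for the argmax problem is to treat the selection as a sampled action and train the classifier with the score-function (REINFORCE) estimator, using the SR loss as a cost. The sketch below assumes a batch of 4 and uses a dummy `sr_loss` function as a hypothetical placeholder for running the actual SRM; in a real setup the SRM itself would still be trained on its own loss with ordinary backpropagation.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # stand-in for the classifier output, shape (B, N)
refs = torch.rand(10, 3, 128, 128)                # fixed stack of reference images


def sr_loss(ref: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder: run the SRM with this reference and return
    # one loss value per sample. Here it is just a dummy image statistic.
    return (ref - 0.5).pow(2).mean(dim=(1, 2, 3))  # shape (B,)


dist = torch.distributions.Categorical(logits=logits)
idx = dist.sample()                               # hard choice per sample, shape (B,)
chosen = refs[idx]                                # (B, C, H, W), real reference images

with torch.no_grad():
    cost = sr_loss(chosen)                        # treated as a fixed cost (negative reward)
    # Leave-one-out baseline: mean cost of the other samples, reduces variance
    baseline = (cost.sum() - cost) / (cost.numel() - 1)

# Score-function surrogate: its gradient w.r.t. the logits is an estimate
# of the gradient of the expected cost.
policy_loss = ((cost - baseline) * dist.log_prob(idx)).mean()
policy_loss.backward()
print(idx.shape, logits.grad is not None)         # torch.Size([4]) True
```

The score-function estimator needs no relaxation, but it is much noisier than backpropagation, so it usually needs a good baseline and more training steps.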