Classification with multi-dimensional labels

Hi everyone,
I am struggling to find out whether an idea I had is possible in PyTorch.

My goal is to have an image classifier that selects the most relevant reference image for a reference-based super-resolution model.


[Figure: pipeline diagram (green: forward pass, blue: backward pass)]

First, an image is sent to a classifier. The classifier's output should select one image from a fixed stack of reference images. The selected reference image is then sent to the super-resolution model (SRM).

Is it possible to map a classification output to a discrete choice among these "3D classes" (each class is an entire image, i.e. a 3D tensor) without losing the ability to differentiate through the whole pipeline?

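import torch
import torch.nn as nn
import torchvision.models as models

class RefSelector(nn.Module):
    """ResNet-18 backbone that outputs one logit per reference image."""

    def __init__(self):
        super().__init__()
        # Pretrained backbone (weights API of torchvision >= 0.13)
        self.RSel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Freeze the pretrained layers...
        for param in self.RSel.parameters():
            param.requires_grad_(False)
        # ...and replace the final layer; the new layer is trainable by default
        self.RSel.fc = nn.Linear(in_features=512, out_features=10, bias=True)

    def forward(self, LR):
        # LR: low-resolution input batch of shape (N, 3, H, W)
        return self.RSel(LR)

ref_selector = RefSelector()
inputImage = torch.rand(1, 3, 128, 128)   # ResNet expects a batch dimension
refImages = torch.rand(10, 3, 128, 128)   # fixed stack of 10 reference images
output = ref_selector(inputImage)         # logits of shape (1, 10)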

My first idea was to use an off-the-shelf image classifier such as ResNet, but if I apply

ReferenceID = torch.argmax(output, dim=1)

to the output and use the resulting index to select the reference image, no gradient can flow back to the classifier in the backward pass: argmax returns an integer index that autograd does not track, and the selected reference image is a fixed tensor that does not depend on the classifier's output values at all.
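A minimal sketch that makes the broken gradient path visible. It assumes random logits as a stand-in for the classifier output (`logits` and `ref_images` are illustrative names, not from the original code) and only inspects what autograd records:

```python
import torch

torch.manual_seed(0)

# Stand-in for the classifier output: logits for a batch of 1 over 10 reference images
logits = torch.randn(1, 10, requires_grad=True)
ref_images = torch.rand(10, 3, 128, 128)  # fixed stack of reference images

# argmax returns an integer index tensor; autograd records no grad_fn for it
reference_id = torch.argmax(logits, dim=1)
print(reference_id.dtype, reference_id.requires_grad, reference_id.grad_fn)

# Indexing the fixed stack with that index gives a tensor with no link to `logits`,
# so any loss computed from it cannot send a gradient to the classifier
selected = ref_images[reference_id]
print(selected.shape, selected.requires_grad)
```

Calling `selected.sum().backward()` here would raise a `RuntimeError`, because nothing in the graph requires a gradient.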
Does anyone have an idea whether, and how, this could be possible?
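One commonly used way to keep a discrete choice trainable is the straight-through Gumbel-softmax estimator, available in PyTorch as `torch.nn.functional.gumbel_softmax`. The sketch below is only one possible approach, under stated assumptions: random logits stand in for the classifier output, the temperature `tau=1.0` is an arbitrary choice, and `selected.mean()` is a hypothetical placeholder for the real SRM loss. The forward pass picks exactly one reference image per sample, while the backward pass uses the gradient of the soft relaxation, which is a biased gradient estimate.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_refs = 10
# Stand-in for the classifier output (batch of 2, one logit per reference image)
logits = torch.randn(2, num_refs, requires_grad=True)
ref_images = torch.rand(num_refs, 3, 128, 128)  # fixed stack of reference images

# hard=True: the forward value is a one-hot vector (a discrete choice), while the
# backward pass uses the gradient of the soft Gumbel-softmax sample (straight-through)
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)  # shape (2, 10)

# "Select" one image per sample as a one-hot-weighted sum over the stack.
# In the forward pass this equals ref_images[one_hot.argmax(-1)].
selected = torch.einsum("bn,nchw->bchw", one_hot, ref_images)  # (2, 3, 128, 128)

# Hypothetical placeholder for the super-resolution loss computed by the SRM
loss = selected.mean()
loss.backward()

# Unlike the argmax version, a gradient now reaches the classifier logits
print(logits.grad.shape)
```

With `hard=False` the same code blends all reference images with soft weights instead, which is fully differentiable but no longer a discrete selection. At inference time one could switch back to a plain `torch.argmax` over the logits.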