Loss function for multiple inputs and one target


I am working with images that combine two MNIST digits, and the only label I have is the sum of the two digits mod 10. I built the network as follows:

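import torch
import torch.nn as nn

class Module(nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2):
        super().__init__()
        # One shared MLP maps each flattened 28x28 digit to 10 class logits.
        # ReLU between the Linear layers: without a nonlinearity, stacked
        # Linear layers collapse into a single linear map.
        self.layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, 10),
        )

    def forward(self, input1, input2):
        # Stack both batches so the same weights process both digits
        output = torch.cat((input1, input2), dim=0)
        # Flatten each image to a vector of length input_dim
        output = output.view(output.size(0), -1)
        output = self.layer(output)  # (2 * batch, 10) logits

        # Split back into the logits for the first and the second digit
        output = torch.split(output, input1.size(0))
        return output[0], output[1]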
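The concatenate-then-split trick in forward() runs both digits through the same weights. A quick check, using a small stand-in for self.layer (sizes chosen arbitrarily) and random tensors in place of MNIST images, suggests it produces the same logits as applying the layer to each batch separately:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in for self.layer; the sizes are arbitrary for this demo
layer = nn.Sequential(nn.Linear(784, 32), nn.ReLU(), nn.Linear(32, 10))

batch = 4
images1 = torch.randn(batch, 1, 28, 28)  # random stand-ins for digit images
images2 = torch.randn(batch, 1, 28, 28)

# What forward() does: concatenate, flatten, run once, split
stacked = torch.cat((images1, images2), dim=0)
out = layer(stacked.view(stacked.size(0), -1))
out1, out2 = torch.split(out, images1.size(0))

# Equivalent: apply the same layer to each batch on its own
sep1 = layer(images1.flatten(1))
sep2 = layer(images2.flatten(1))

print(out1.shape, out2.shape)
print(torch.allclose(out1, sep1, atol=1e-5), torch.allclose(out2, sep2, atol=1e-5))
```

Note that torch.split with a single split size yields exactly two chunks only when both batches have the same size, which holds when the dataloader yields matched pairs.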

There are now two separate outputs, one for each digit image. I then need to combine these two outputs so that they match a single target value, (image1_label + image2_label) mod 10.
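If the two digits are assumed to be recognized independently, the two outputs already determine a distribution over the target. Writing p1 and p2 for the softmax of output1 and output2 over the digits 0 to 9 (notation introduced here for illustration), the probability of each target value s, and the resulting per-batch loss, can be sketched as:

```latex
P_n(y = s) = \sum_{a=0}^{9} p_{1,n}(a)\, p_{2,n}\big((s - a) \bmod 10\big), \qquad s = 0, \dots, 9

\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log P_n(y = y_n)
```

Here y_n is the target of example n. Since P_n is a distribution over 10 classes, the task is still a 10-way classification; the loss is the usual negative log-likelihood, applied to the combined distribution rather than to either output alone.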

This is where I have been stuck for a couple of days. I don't know how to build a loss function from these two outputs and one target, and I am not sure whether this is really a classification problem, i.e. whether CrossEntropyLoss is the right choice. I would appreciate any hints that help me move forward.
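A minimal PyTorch sketch of one possible approach, assuming the two digits are independent given their images: turn both outputs into log-probabilities, marginalize over every digit pair whose sum mod 10 equals each target, and apply a negative log-likelihood loss. The helper name sum_mod10_log_probs is hypothetical, and the random tensors stand in for the model's outputs.

```python
import torch
import torch.nn.functional as F

def sum_mod10_log_probs(logits1, logits2):
    """Hypothetical helper: log-probabilities of (digit1 + digit2) mod 10, shape (batch, 10).

    Assumes the two digits are independent given their images:
    P(s) = sum_a p1(a) * p2((s - a) mod 10).
    """
    logp1 = F.log_softmax(logits1, dim=1)  # (B, 10)
    logp2 = F.log_softmax(logits2, dim=1)  # (B, 10)
    digits = torch.arange(10, device=logits1.device)
    # idx[s, a] = (s - a) mod 10: the second digit needed for target s when the first is a
    idx = (digits.view(10, 1) - digits.view(1, 10)) % 10
    # terms[n, s, a] = log p1(a) + log p2((s - a) mod 10)
    terms = logp1.unsqueeze(1) + logp2[:, idx]
    # Sum over a in probability space, computed stably in log space
    return torch.logsumexp(terms, dim=2)

# Example with random logits standing in for output1 / output2
logits1, logits2 = torch.randn(4, 10), torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 9])
log_probs = sum_mod10_log_probs(logits1, logits2)
loss = F.nll_loss(log_probs, labels)
predicted = log_probs.argmax(1)  # most likely target under the combined distribution
```

In train_model, the ???????? placeholder would then become something like loss = F.nll_loss(sum_mod10_log_probs(output1, output2), labels). Because the helper returns normalized log-probabilities, nll_loss is the natural fit; it is the same negative log-likelihood that CrossEntropyLoss computes from raw logits. The argmax of the combined distribution can differ from (argmax1 + argmax2) % 10 when the per-digit predictions are uncertain.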

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Module(input_dim=28*28, hidden_dim1=2048, hidden_dim2=1024).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

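def train_model(model, dataloader, optimizer, loss_fn):
    model.train()
    losses = []
    predictions = 0  # running count of correctly predicted targets
    for images1, images2, labels, _, _ in dataloader:
        images1 = images1.to(device)
        images2 = images2.to(device)
        labels = labels.to(device)

        output1, output2 = model(images1, images2)

        # This is where I am stuck: how do I combine output1 and output2
        # into something loss_fn can compare with labels?
        # loss = loss_fn(????????, labels)
        # Once the loss exists, the usual update step would follow:
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        # losses.append(loss.item())

        # Predicted target: sum of the two predicted digits, mod 10
        predicted = (output1.argmax(1) + output2.argmax(1)) % 10
        predictions += accuracy(labels, predicted)
    train_loss = sum(losses) / len(losses)
    train_accuracy = 100 * predictions / len(dataloader.dataset)
    return train_loss, train_accuracy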
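train_model calls an accuracy helper that is not shown in the post. Judging from how its result is accumulated and then divided by len(dataloader.dataset) and multiplied by 100, it presumably returns the number of correct predictions in a batch. A hypothetical version under that assumption:

```python
import torch

def accuracy(labels, predicted):
    """Hypothetical helper: number of correct predictions in a batch.

    Returns a count, not a fraction, so that summing over batches and
    dividing by the dataset size gives the overall accuracy.
    """
    return (predicted == labels).sum().item()

# Example: 3 of the 4 predicted targets match the labels
labels = torch.tensor([3, 7, 0, 9])
predicted = torch.tensor([3, 7, 1, 9])
print(accuracy(labels, predicted))  # 3
```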