Compute the loss of a moving dataset

Hi guys,
I’m new to the world of machine learning, so it could be that my question is trivial or incorrectly posed.

I am using a moving dataset whose images I pass through an STN (Spatial Transformer Network). I forward each image to the STN individually, then restack all of the output images together in a tuple.
My problem lies in the loss calculation. My target is a torch.Tensor of size [2,1,64,64], and my prediction is a torch.Tensor of size [2,1,10,64,64], which means that the prediction and the target do not have the same shape.
Could someone suggest an approach? The only idea I have is to keep just the last STN output, i.e. a slice of size [2,1,1,64,64] from my prediction, squeeze it to [2,1,64,64], and then calculate the loss against the target.
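
In code, the idea would look roughly like the sketch below. The random tensors only stand in for my real data, and I am assuming that dimension 2 (the one of size 10) indexes the individual STN outputs. The MSE loss is just a placeholder, not necessarily the loss I should be using.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder tensors with the shapes from my setup (random values, not real data).
# Assumed layout of the prediction: [batch, channel, frame, height, width].
prediction = torch.rand(2, 1, 10, 64, 64)
target = torch.rand(2, 1, 64, 64)  # [batch, channel, height, width]

# Keep only the last STN output along the (assumed) frame dimension.
last_frame = prediction[:, :, -1:, :, :]  # shape [2, 1, 1, 64, 64]

# Drop the singleton frame dimension so the shape matches the target.
last_frame = last_frame.squeeze(2)  # shape [2, 1, 64, 64]

# MSE is only a placeholder loss for this sketch.
loss = F.mse_loss(last_frame, target)
print(last_frame.shape, loss.item())
```

Indexing with prediction[:, :, -1] would give the [2,1,64,64] tensor directly, without the separate squeeze.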

Thank you in advance