Hope you are doing well.
I am programming a 3D U-Net and get output tensors of shape [4,4,128,128,128]. I assume the minibatch size of 4 comes from the torch.cat() function. But now I get the following error from the softmax classification over the 4 classes to predict:
ValueError: Expected input batch_size (4) to match target batch_size (1).
I think this is because the mask data has the shape [1,1,128,128,128], but I do not know how to overcome this error. Could someone help me?
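For reference, here is a minimal sketch of the two shape rules that seem to be involved. It assumes the loss is `nn.CrossEntropyLoss` (which applies log-softmax internally) and that the skip connections are built with `torch.cat`. The tensor names `up` and `skip` and the reduced 8x8x8 volume are hypothetical stand-ins, not taken from the actual network. The sketch shows that `torch.cat` concatenates along dim 0 (the batch dimension) unless `dim=1` is passed, and that the loss expects class-index targets of shape [N, D, H, W] without a channel dimension.

```python
import torch
import torch.nn as nn

# Reduced spatial size (8^3 instead of 128^3) so the sketch runs quickly.
D = H = W = 8
NUM_CLASSES = 4

# Hypothetical decoder feature map and skip connection for ONE input volume (batch size 1).
up = torch.randn(1, 16, D, H, W)
skip = torch.randn(1, 16, D, H, W)

# torch.cat defaults to dim=0, the batch dimension, so the batch size grows.
cat_batch = torch.cat([up, skip])
# Concatenating along the channel dimension keeps the batch size at 1.
cat_channel = torch.cat([up, skip], dim=1)

print(cat_batch.shape)    # torch.Size([2, 16, 8, 8, 8])
print(cat_channel.shape)  # torch.Size([1, 32, 8, 8, 8])

# nn.CrossEntropyLoss expects logits of shape [N, C, D, H, W]
# and integer class-index targets of shape [N, D, H, W] with the same N.
criterion = nn.CrossEntropyLoss()
logits = torch.randn(1, NUM_CLASSES, D, H, W)
mask = torch.randint(0, NUM_CLASSES, (1, 1, D, H, W))  # [1, 1, D, H, W], like the mask above

target = mask.squeeze(1).long()  # drop the channel dim -> [1, D, H, W]
loss = criterion(logits, target)

print(target.shape)  # torch.Size([1, 8, 8, 8])
print(loss.dim())    # 0, i.e. a scalar with the default reduction="mean"
```

If the network output keeps a batch size of 1 (for example, with `dim=1` in the skip-connection concatenations), squeezing the channel dimension of the mask should be enough for the batch sizes to match. This is an assumption about the model, not a confirmed diagnosis.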