Reduce number of Minibatches

Hey,

Hope you are doing well.

I am implementing a 3D U-Net and get output tensors of shape [4,4,128,128,128]. I assume the minibatch size of 4 comes from the torch.cat() function. But now I get an error from the softmax classification over the 4 classes to predict:

ValueError: Expected input batch_size (4) to match target batch_size (1).

I think this is because the mask data has the shape [1,1,128,128,128], but I do not know how to fix this error. Could someone help me?
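A minimal sketch of the mismatch, assuming the softmax classification is `nn.CrossEntropyLoss` (which applies log-softmax internally) and shrinking the volumes from 128³ to 8³ to keep it cheap. `nn.CrossEntropyLoss` expects logits of shape [N, C, D, H, W] and a target of class indices of shape [N, D, H, W], so the batch dimension N has to agree and the mask's singleton channel dimension has to be squeezed away:

```python
import torch
import torch.nn as nn

# Assumption: the loss is nn.CrossEntropyLoss; spatial size reduced from 128^3 to 8^3.
NUM_CLASSES = 4
D = H = W = 8

criterion = nn.CrossEntropyLoss()

# Logits as described in the question: batch size 4 although only one sample was fed in.
logits_wrong = torch.randn(4, NUM_CLASSES, D, H, W)
# Mask as described: [N, 1, D, H, W] holding integer class labels (int64 by default).
mask = torch.randint(0, NUM_CLASSES, (1, 1, D, H, W))

# Reproduce the error: input batch size 4 vs. target batch size 1.
try:
    criterion(logits_wrong, mask.squeeze(1))
    mismatch_error = None
except (ValueError, RuntimeError) as err:
    mismatch_error = str(err)

# With a single-sample input the logits have batch size 1, matching the mask.
logits = torch.randn(1, NUM_CLASSES, D, H, W)
# Target must be [N, D, H, W] class indices, so drop the channel dimension.
target = mask.squeeze(1).long()
loss = criterion(logits, target)

print(mismatch_error)
print(loss.item())
```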

I don’t know where the torch.cat function is used, but note that your model should not change the batch size. Since your target contains a single sample, I would assume your model input was also originally a single sample?
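It is not clear from the thread where `torch.cat` is called, but one common way a U-Net accidentally changes the batch size is concatenating skip connections without `dim=1`: `torch.cat` defaults to `dim=0`, which is the batch dimension. The tensors below are hypothetical decoder features for a single sample, with reduced sizes:

```python
import torch

# Hypothetical U-Net decoder step: upsampled features and the matching encoder
# skip connection, both for a single sample (batch size 1).
up = torch.randn(1, 32, 8, 8, 8)    # [N, C_up, D, H, W]
skip = torch.randn(1, 32, 8, 8, 8)  # [N, C_skip, D, H, W]

# Intended: concatenate along the channel dimension, so the batch size stays 1.
merged = torch.cat([up, skip], dim=1)

# Likely bug: without dim, torch.cat stacks along dim=0 and grows the batch.
wrong = torch.cat([up, skip])

print(merged.shape)  # torch.Size([1, 64, 8, 8, 8])
print(wrong.shape)   # torch.Size([2, 32, 8, 8, 8])
```

Checking `output.shape[0]` against `input.shape[0]` after each block is a quick way to find where the batch dimension changes.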


Yes, I guess I had a wrong initialization in the test/train split. Thanks!
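Since the cause turned out to be in the data split, here is a hypothetical sketch (toy tensors, not the original data) of splitting by sample index with `random_split`, which keeps each volume paired with its own mask, plus an assertion that catches a batch-size mismatch before it reaches the loss:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical toy dataset: 10 single-channel volumes with matching masks.
# Spatial size reduced to 8^3; the data in the thread is 128^3.
images = torch.randn(10, 1, 8, 8, 8)           # [num_samples, C, D, H, W]
masks = torch.randint(0, 4, (10, 1, 8, 8, 8))  # [num_samples, 1, D, H, W]
dataset = TensorDataset(images, masks)

# Split by sample index so every image stays paired with its mask.
generator = torch.Generator().manual_seed(0)
train_set, test_set = random_split(dataset, [8, 2], generator=generator)

train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

# Sanity check: input and target batch sizes must agree for every batch.
for x, y in train_loader:
    assert x.shape[0] == y.shape[0], (x.shape, y.shape)

x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([1, 1, 8, 8, 8]) torch.Size([1, 1, 8, 8, 8])
```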