DataParallel, Expected input batch_size (64) to match target batch_size (32)

model = nn.DataParallel(model, device_ids=[0, 1])  # wrap the model for 2 GPUs

context, ctx_length = batch.context
response, rsp_length = batch.response
label = batch.label

prediction = self.model(context, response)  # forward pass through DataParallel
loss = self.criterion(prediction, label)    # the criterion raises the error here

The batch size is 32 and the size of the prediction is 32 * 2 = 64, but the size of the label is still 32, which causes the criterion to raise an error.

What’s the problem?


How come the predictions don't match the labels? You should set the model's output size to 1 per example (so it will be 32 if the batch size is 32).

Actually, the size of the prediction is [64, 2], while the size of the label is [32]. That's what confuses me. When I change device_ids=[0], the size of the prediction becomes [32, 2].

I don’t know what model you’re using so it’s hard for me to help you, but to find the problem I would probably change the batch size to 1 and see what happens, just to make sure that the model’s output size is really 1.

Hi. Were you able to resolve this?
I am also using DataParallel to use multiple GPUs in a PyTorch script, but I face a similar error with batch size 64 and 4 GPUs:

ValueError: Expected input batch_size (256) to match target batch_size (64)

It is a bit late to answer, but I'll leave the answer here for other people.
I recently had the same problem with multi-GPU usage.

Problem: It seems the issue is that we need to pass Tensor-type variables to model.forward when using DataParallel.
Otherwise, the input is not scattered along the batch dimension, but copied once per GPU.
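For example, something like this (a minimal sketch with made-up names, assuming a 2-GPU machine) shows the copying behavior: the list input is replicated to both replicas, so each one processes the full batch of 32 and the gathered output is [64, 2].

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        # Anti-pattern: type casting inside forward. DataParallel only splits
        # Tensor arguments along dim 0; a plain Python list is copied to every
        # replica instead, so each GPU sees all 32 samples.
        x = torch.tensor(x, dtype=torch.float32, device=self.fc.weight.device)
        return self.fc(x)

model = nn.DataParallel(ToyModel().cuda(), device_ids=[0, 1])
raw_batch = [[0.0] * 10 for _ in range(32)]   # non-tensor input, batch size 32
print(model(raw_batch).shape)                 # torch.Size([64, 2]) on 2 GPUs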

Solution: the easy solution is to pass a Tensor to model.forward instead of type casting inside the forward function.
Another solution is to return the loss from the forward function and then normalize the loss by the batch size.
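Continuing the toy sketch above (again just an illustration with made-up names, not the exact code from this thread), the two options look roughly like this:

# Option 1: cast to a Tensor *before* calling the model, and keep forward
# tensor-only, so DataParallel can scatter the input along dim 0
# (each of the 2 GPUs then gets 16 of the 32 samples).
class FixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):                      # x is already a Tensor here
        return self.fc(x)

model = nn.DataParallel(FixedModel().cuda(), device_ids=[0, 1])
x = torch.tensor(raw_batch, dtype=torch.float32).cuda()
label = torch.randint(0, 2, (32,)).cuda()
prediction = model(x)                          # torch.Size([32, 2])
loss = nn.CrossEntropyLoss()(prediction, label)

# Option 2: compute the loss inside forward so shapes always match per
# replica, then normalize the gathered per-replica losses by the batch size.
class FixedModelWithLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
        self.criterion = nn.CrossEntropyLoss(reduction='sum')

    def forward(self, x, label):
        return self.criterion(self.fc(x), label)   # one summed loss per replica

model = nn.DataParallel(FixedModelWithLoss().cuda(), device_ids=[0, 1])
loss = model(x, label).sum() / x.size(0)       # mean loss over the whole batch
loss.backward()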

Please see here for details about DataParallel.


@sdeva14 Hey, can you please post how to do the first solution? It is not clear to me.

Thanks in advance!

@Aayushee_Gupta Hey, did you resolve this?

@zeng Hey, how did you resolve this?