Error in torch.nn.DataParallel

So are you multiplying the batch size by the number of GPUs (9)?
nn.DataParallel will chunk the batch in dim0 and send each piece to a GPU.
Since you see an input of shape [10, 396] inside the forward method both on a single GPU and on each GPU when using nn.DataParallel, the batch you feed into the nn.DataParallel model should have the shape [90, 396].
Is my assumption correct?
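
Here is a minimal sketch you could use to verify this; the feature size 396, the per-GPU batch size of 10, and the toy linear model are assumptions for illustration, not your actual setup:

```python
import torch
import torch.nn as nn

# Toy model that reports the per-replica batch shape it receives.
# in_features=396 is assumed from the shapes discussed above.
class ToyModel(nn.Module):
    def __init__(self, in_features=396, out_features=10):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Each GPU replica sees only its chunk of the batch along dim0.
        print(f"forward on {x.device}: input shape {tuple(x.shape)}")
        return self.fc(x)


if __name__ == "__main__":
    num_gpus = max(torch.cuda.device_count(), 1)
    model = ToyModel().cuda()
    if num_gpus > 1:
        model = nn.DataParallel(model)

    # With e.g. 9 GPUs and a per-GPU batch size of 10, feed a [90, 396] batch;
    # DataParallel splits it into nine [10, 396] chunks, one per device,
    # then gathers the outputs back on the default device.
    batch = torch.randn(10 * num_gpus, 396).cuda()
    out = model(batch)
    print(f"output shape: {tuple(out.shape)}")
```

If the print inside forward shows [10, 396] on each device while the input to the wrapped model was [90, 396], the batch is being chunked as described.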
