So are you multiplying the batch size by the number of GPUs (9)? `nn.DataParallel` will chunk the batch along dim 0 and send each piece to a GPU. Since you get `[10, 396]` inside the `forward` method for a single GPU as well as for multiple GPUs using `nn.DataParallel`, your provided batch should have the shape `[90, 396]` before feeding it into the `nn.DataParallel` model.
Is my assumption correct?
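
A minimal sketch to check the per-replica shape yourself, assuming a CUDA machine with multiple GPUs; `ToyModel` and its layer sizes are hypothetical stand-ins using the `396` feature dimension from this thread:

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Hypothetical toy model just to print the per-GPU batch shape.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(396, 10)

    def forward(self, x):
        # With nn.DataParallel, each replica sees a chunk of the batch,
        # e.g. [10, 396] when the full input is [90, 396] on 9 GPUs.
        print("per-replica input shape:", x.shape)
        return self.fc(x)

model = nn.DataParallel(ToyModel().cuda())  # replicates across all visible GPUs
x = torch.randn(90, 396, device="cuda")     # global batch of 90 samples
out = model(x)                              # each replica's forward prints its chunk
print("gathered output shape:", out.shape)  # [90, 10], gathered on the default device
```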