Huge loss with DataParallel

The DataParallel module is pretty straightforward: it splits the input into N chunks (typically along the first dimension), runs the same forward pass on N replicas of your model, and gathers the outputs back into a single tensor (along the same dimension the input was split on). Gradients are always accumulated into the source model (not into the replicas).
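
Roughly, each forward pass boils down to something like the sketch below, built from the `torch.nn.parallel` primitives (`scatter`, `replicate`, `parallel_apply`, `gather`). This is a simplification, not the actual implementation; the real `DataParallel.forward` also handles kwargs, device checks, and the single-device case.

```python
import torch.nn as nn

def data_parallel(module, inp, device_ids, output_device=None):
    # Simplified sketch of what nn.DataParallel.forward does on every call.
    # `module` is assumed to already live on device_ids[0].
    if output_device is None:
        output_device = device_ids[0]

    # 1. Split the input into one chunk per device (along dim 0 by default).
    inputs = nn.parallel.scatter(inp, device_ids)
    # 2. Copy the model onto each device that got a chunk (fresh replicas every call).
    replicas = nn.parallel.replicate(module, device_ids[:len(inputs)])
    # 3. Run the forward pass on each (replica, chunk) pair in parallel.
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    # 4. Concatenate the per-device outputs back on the output device.
    return nn.parallel.gather(outputs, output_device)
```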

It looks like updates to buffers in these replicas don’t propagate back to the source model, as the model replicas are tossed after every forward pass. Perhaps this is a starting point for your investigation?
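
One way to check (a rough sketch; the toy model and names here are just for illustration, and you'd want to run it on a multi-GPU machine for the comparison to mean anything) is to look at a buffer such as BatchNorm's `running_mean` on the wrapped module before and after a DataParallel forward pass, and compare it to what a plain forward on the source model does:

```python
import torch
import torch.nn as nn

# Toy model whose BatchNorm layer updates buffers (running_mean / running_var)
# during forward in training mode. Purely illustrative.
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4)).cuda()
dp_model = nn.DataParallel(model)

x = torch.randn(32, 8, device="cuda")

print("running_mean before:     ", model[1].running_mean)
dp_model(x)  # forward through the replicas
print("running_mean after DP:   ", model[1].running_mean)

# For comparison: the same forward on the source model directly.
model[1].reset_running_stats()
model(x)
print("running_mean single pass:", model[1].running_mean)
```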