How to fix the "gather along dim 0" warning in a multi-GPU (DataParallel) setting?

I just had this error myself this week. The cause is a function or method (probably the forward method of one of your modules) that returns a scalar, such as an accuracy metric or a loss value.

See this issue for more details. This is just a warning, but it could mean that your code doesn’t work properly.

For example, my use case was a learnable loss shared across GPUs with DataParallel. Its forward function returned only a scalar, so DataParallel stacked the per-GPU scalars into a vector, which caused an error further on. The fix for this specific problem was to call loss = loss.mean(), which is what should happen when the per-GPU values are reassembled anyway.
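Here is a minimal sketch of the pattern. The LearnableLoss module, its parameter, and the tensor shapes are illustrative assumptions, not my original code; the key point is that a 0-dim output per replica gets gathered into a vector of length num_gpus, which you then reduce with .mean():

```python
import torch
import torch.nn as nn

# Hypothetical module whose forward returns a 0-dim (scalar) loss tensor.
class LearnableLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, pred, target):
        # .mean() produces a 0-dim tensor; gathering 0-dim outputs is
        # what triggers the DataParallel warning.
        return (self.weight * (pred - target) ** 2).mean()

# Assumes at least 2 visible GPUs; DataParallel splits the batch (dim 0)
# across replicas and gathers their outputs back on the source device.
criterion = nn.DataParallel(LearnableLoss().to("cuda"))

pred = torch.randn(8, 10, device="cuda")
target = torch.randn(8, 10, device="cuda")

loss = criterion(pred, target)  # shape: (num_gpus,), one scalar per replica
loss = loss.mean()              # reduce the per-GPU scalars to a single value
loss.backward()
```

Without the final .mean(), the gathered loss is a vector, and anything downstream that expects a scalar (like a bare loss.backward()) will fail or misbehave.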
