How to fix the "gather along dimension 0" warning in a multi-GPU (DataParallel) setting?

I keep getting this warning but have no idea what the issue might be, since it doesn't say which line of code triggers it. Does anyone have thoughts on what the common cause(s) could be?

python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.

I ran into this myself just this week. The cause is that you have a function or method (probably the forward method of one of your modules) that returns a scalar, such as an accuracy or a loss.

See this issue for more details. This is just a warning, but it could mean that your code doesn’t work properly.

For example, my use case was a learnable loss that was shared across GPUs with DataParallel but returned only a scalar from its forward function. DataParallel then stacked those scalars into a vector, which caused an error further on. The solution to this specific problem was to call loss = loss.mean(), which is the reduction you want when combining the per-GPU values anyway.
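To make the "stacked in a vector" step concrete, here is a small CPU-only sketch that emulates roughly what DataParallel's gather does with scalar outputs; it does not call DataParallel itself, and the two loss values are made up for illustration:

```python
import torch

# Hypothetical per-replica outputs: each GPU's forward returned a 0-dim loss.
per_replica_losses = [torch.tensor(0.5), torch.tensor(0.7)]

# Roughly what DataParallel's gather does after emitting the warning: view each
# scalar as shape (1,) and concatenate along dim 0, one entry per GPU.
gathered = torch.cat([t.view(1) for t in per_replica_losses], dim=0)
print(gathered.shape)  # torch.Size([2])

# Reduce the per-GPU vector back to a single scalar before backward().
loss = gathered.mean()
print(loss.dim())  # 0
```

Anything downstream that expects a 0-dim loss will misbehave on `gathered` until it is reduced.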

Thanks for sharing the insights. I did use loss = loss.mean() before calling backward. Do you know whether there is an easy way to check which line exactly causes the problem?

I don't know if you can check the internals of DataParallel, but in my case I simply printed the loss inside the forward method, where it was a scalar, and outside the method (i.e., in the training script), where it was a vector.

Actually, I just re-read the issue on GitHub and realized that calling .mean() might be a weak solution if your batch is not evenly divided across the GPUs. The proper solution seems to be to scale each per-GPU loss by its local batch size, sum these, and then divide by the total batch size.
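A minimal sketch of that weighted reduction, assuming a per-sample loss such as cross-entropy. `SumLoss` is a hypothetical module: it returns the summed loss (equivalent to the local mean scaled by the local batch size) together with the local batch size, both as 1-element tensors so they are concatenated without the scalar warning. The division then happens once, outside DataParallel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SumLoss(nn.Module):
    # Hypothetical loss module: returns (sum of per-sample losses, local batch size),
    # each with shape (1,) so DataParallel concatenates them along dim 0.
    def forward(self, logits, target):
        loss_sum = F.cross_entropy(logits, target, reduction="sum")
        count = torch.tensor([logits.size(0)], dtype=logits.dtype, device=logits.device)
        return loss_sum.unsqueeze(0), count


model = nn.DataParallel(SumLoss())
# 7 samples: with 2 GPUs the replicas get uneven shares (e.g. 4 and 3).
logits = torch.randn(7, 3)
target = torch.randint(0, 3, (7,))

loss_sums, counts = model(logits, target)
# Mean over all samples, not a mean of per-GPU means.
loss = loss_sums.sum() / counts.sum()
```

The result matches a plain `F.cross_entropy(logits, target)` over the whole batch regardless of how the samples were split.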


I found that this doesn't suppress the warning. Unsqueezing the scalar along dim 0 before returning it from inside the DataParallel-wrapped module did the trick.
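A sketch of that fix, with `LossModule` as a hypothetical loss module: returning a shape-(1,) tensor instead of a 0-dim one gives the gather something it can concatenate along dim 0, so no warning is raised, and the per-GPU vector is reduced afterwards. The uneven-batch caveat about .mean() still applies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LossModule(nn.Module):
    # Hypothetical loss module wrapped in DataParallel.
    def forward(self, logits, target):
        loss = F.cross_entropy(logits, target)  # 0-dim tensor on each replica
        return loss.unsqueeze(0)                # shape (1,): gathered without the warning


model = nn.DataParallel(LossModule())
logits = torch.randn(8, 4, requires_grad=True)
target = torch.randint(0, 4, (8,))

per_gpu_loss = model(logits, target)  # (num_gpus,) with several GPUs, (1,) otherwise
loss = per_gpu_loss.mean()
loss.backward()
```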