I just ran into this error myself this week. The cause is that you have a function or method (probably the `forward` method of one of your modules) that returns a scalar, such as an accuracy or a loss value. See this issue for more details. This is just a warning, but it can mean that your code doesn't work properly.
For example, my use case was a learnable loss shared across GPUs with `DataParallel`: its `forward` function returned only a scalar, which `DataParallel` then stacked into a vector and caused an error further on. The solution to this specific problem was to call `loss = loss.mean()`, which is what should happen when reassembling the per-GPU values anyway.
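To make the situation concrete, here is a minimal sketch (the `LearnableLoss` module and its tensors are hypothetical, not from my actual code). A module whose `forward` returns a 0-dim tensor will, under `nn.DataParallel` with N GPUs, have its N scalar outputs gathered into a 1-D tensor of shape (N,); the snippet simulates that gather on CPU and shows the `.mean()` fix:

```python
import torch
import torch.nn as nn

# Hypothetical learnable loss: its forward returns a 0-dim (scalar)
# tensor, which is what triggers the warning under nn.DataParallel.
class LearnableLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, pred, target):
        # .mean() over the batch produces a scalar (0-dim tensor)
        return (self.weight * (pred - target) ** 2).mean()

criterion = LearnableLoss()
pred = torch.randn(8, 4)
target = torch.randn(8, 4)

# Simulate what DataParallel does with 2 replicas: each half of the
# batch yields a scalar, and gather stacks them into a shape-(2,) vector.
gathered = torch.stack([
    criterion(pred[:4], target[:4]),
    criterion(pred[4:], target[4:]),
])

# The fix: reduce the gathered per-GPU losses back to a single scalar
# before calling backward() or using it downstream.
loss = gathered.mean()
```

After the `.mean()`, `loss` is again a 0-dim tensor, so the rest of the training loop (e.g. `loss.backward()`) works the same as in the single-GPU case.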