I found the reason: the model outputs are gathered back onto one GPU and the loss is calculated there, so that GPU holds every replica's full output tensors. If you move the loss calculation into model.forward(), the problem is resolved.
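A minimal sketch of that workaround, assuming the multi-GPU wrapper is PyTorch's `nn.DataParallel` (the comment doesn't name it). `ModelWithLoss` and `train_step` are hypothetical names used for illustration. Each replica computes its own loss inside `forward()`, so only a one-element loss tensor per GPU is gathered instead of the full outputs:

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: runs the model AND the criterion inside forward().

    Under nn.DataParallel each replica returns only its loss, so the full
    outputs are never gathered onto a single GPU.
    """

    def __init__(self, model: nn.Module, criterion: nn.Module):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        # Return a 1-element tensor so DataParallel can concatenate the
        # per-replica losses along dim 0 (avoids the scalar-gather warning).
        return loss.unsqueeze(0)


def train_step(wrapped: nn.Module, optimizer, inputs, targets) -> float:
    optimizer.zero_grad()
    per_replica_losses = wrapped(inputs, targets)  # shape: [num_replicas]
    loss = per_replica_losses.mean()  # average the per-GPU losses
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = nn.Linear(16, 4)  # stand-in for the real model (assumption)
    wrapped = nn.DataParallel(ModelWithLoss(net, nn.CrossEntropyLoss())).to(device)
    optimizer = torch.optim.SGD(wrapped.parameters(), lr=0.1)
    x = torch.randn(32, 16, device=device)
    y = torch.randint(0, 4, (32,), device=device)
    print(train_step(wrapped, optimizer, x, y))
```

Averaging the per-replica means matches the global mean only when the batch splits evenly across GPUs; with uneven splits it is a close approximation. On a machine without CUDA, `nn.DataParallel` simply calls the wrapped module, so the same code also runs on CPU.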