Hi guys, I am trying to build a large classification model with a huge number of classes, similar to training a word embedding. At the end of the model there is a linear layer with about 1 million outputs, followed by a softmax layer.

What I do is split the large linear layer into 8 independent layers, with their weights placed on 8 GPUs respectively. The softmax scores can be computed on the same GPUs by gathering the partial sums of exponentials. Then, what is the best way to do the backward pass?
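To make the "gathering sums" step concrete, here is a minimal sketch in plain Python. Each inner list is a hypothetical stand-in for the logits held on one GPU (the function name `sharded_softmax` is illustrative, not a library API). Subtracting a global maximum first is a standard numerical-stability step and is an assumption here; it means each shard exchanges two scalars (its local max and its local sum of exponentials) rather than its full slice of logits.

```python
import math
from typing import List


def sharded_softmax(shards: List[List[float]]) -> List[List[float]]:
    """Softmax over the concatenation of all shards, computed shard by shard.

    `shards[k]` is a hypothetical stand-in for the logits living on GPU k.
    """
    # Step 1: each shard reports its local max; the global max is the max of those.
    global_max = max(max(s) for s in shards)
    # Step 2: each shard reports sum(exp(x - global_max)); the partial sums are added.
    local_sums = [sum(math.exp(x - global_max) for x in s) for s in shards]
    total = sum(local_sums)
    # Step 3: each shard normalizes its own logits with the shared denominator.
    return [[math.exp(x - global_max) / total for x in s] for s in shards]
```

For example, `sharded_softmax([[1.0, 2.0], [3.0], [0.5, -1.0, 4.0]])` returns three lists of probabilities (lengths 2, 1 and 3) that together sum to 1 and match an ordinary softmax over all six logits.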

My current solution is to implement a CrossEntropyLoss manually by picking the softmax probabilities from the different GPUs and calling backward on this loss. This seems quite inefficient. Since the gradients are simple to compute, is it possible to feed them in directly rather than calling `loss.backward()`?
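The "simple" gradient referred to above is the standard result for softmax followed by cross-entropy: the derivative of the loss with respect to logit $i$ is $p_i - \mathbb{1}[i = \text{target}]$, so each shard can build its own gradient from its local probabilities. The sketch below, again in plain Python with lists as hypothetical stand-ins for per-GPU logits and a hypothetical function name, computes the loss for one sample and that per-shard gradient. In PyTorch, `Tensor.backward` accepts a `gradient` argument for non-scalar tensors, which is one possible way to push such a precomputed per-shard gradient into each shard's graph; that wiring is not shown here and is an assumption about how it would be used.

```python
import math
from typing import List, Tuple


def sharded_ce_loss_and_grad(
    shards: List[List[float]], target: int
) -> Tuple[float, List[List[float]]]:
    """Cross-entropy for one sample and its gradient w.r.t. each shard's logits.

    `shards[k]` stands in for the logits on GPU k, in class order;
    `target` is the global class index.
    """
    # log of the softmax denominator, computed from the gathered max and sums
    global_max = max(max(s) for s in shards)
    total = sum(math.exp(x - global_max) for s in shards for x in s)
    log_z = global_max + math.log(total)

    loss = 0.0
    grads: List[List[float]] = []
    offset = 0  # global index of the first class in the current shard
    for s in shards:
        # dL/dlogit_i = softmax_i - 1[i == target]
        g = [math.exp(x - log_z) for x in s]
        if offset <= target < offset + len(s):
            local = target - offset
            g[local] -= 1.0
            loss = log_z - s[local]  # -log softmax_target
        grads.append(g)
        offset += len(s)
    return loss, grads
```

Only the shard that owns the target class subtracts 1; every other shard's gradient is just its slice of the softmax probabilities, so no per-class data has to move between GPUs beyond the scalars already gathered for the forward pass.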