The for loop might indeed be slow.
Could you try to use direct indexing like in this example?
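A minimal sketch of the direct-indexing idea, assuming your labels are available as a 1-D `LongTensor` (the `targets` values here are hypothetical). Indexing the class-weight tensor with the targets gives every sample the weight of its class in one vectorized op, with no Python loop:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels for an imbalanced 3-class dataset.
targets = torch.tensor([0, 0, 0, 0, 1, 1, 2])

# Class weights as inverse class frequency (one common choice, assumed here).
class_counts = torch.bincount(targets)       # samples per class
class_weights = 1.0 / class_counts.double()  # rarer classes get larger weights

# Direct indexing replaces the for loop: sample i gets class_weights[targets[i]].
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(sample_weights), replacement=True
)
```

With inverse-frequency weights, each class ends up with the same total weight, so batches drawn through this sampler should be roughly balanced.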
Currently the weights are passed to `torch.multinomial`, which draws sample indices, so the sampler needs one weight per sample and there is no easy alternative at the moment. I get your issue and like the idea of just providing class weights. Let me think about a good approach.
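To illustrate why per-sample weights are required, here is a rough sketch of the sampling step, assuming it reduces to a single `torch.multinomial` call over the weight vector (check the sampler source for your PyTorch version). The weights below are hypothetical, one entry per dataset sample:

```python
import torch

# Hypothetical per-sample weights: one entry per sample in the dataset.
sample_weights = torch.tensor(
    [0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 1.0], dtype=torch.double
)

# Draw dataset indices with probability proportional to each sample's weight.
# A per-class vector would be interpreted as weights over classes, not samples.
indices = torch.multinomial(
    sample_weights, num_samples=len(sample_weights), replacement=True
)
```

Because `torch.multinomial` returns positions into the weight vector, a weight vector of length `num_classes` would only ever yield indices `0..num_classes-1`, not dataset indices.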