I have an array of shape [batch x k], where each row is a probability distribution over k classes:
[ [0.1, 0.2, 0.7], [0.1, 0.8, 0.1], [0.4, 0.5, 0.1] ] (batch and k are both 3)
For each row (i.e., each batch element), I want to draw one class index sampled according to that row's distribution over the k classes.
A possible sampled output:
[2, 1, 0]
I do not need gradients through this sample. It is purely for visualization and for test time, after the model has already been trained.
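Since no gradient is needed, plain inverse-CDF sampling per row is enough. Below is a minimal NumPy sketch, assuming the probabilities are available as a NumPy array (the question does not name a framework). `sample_per_row` is a hypothetical helper name. It takes the cumulative sum along the class axis, draws one uniform number per row, and returns the first class whose cumulative probability exceeds that draw. If the data lives in PyTorch, `torch.multinomial(probs, 1)` does the same job directly. In TensorFlow, `tf.random.categorical` expects logits, so you would pass `tf.math.log(probs)`.

```python
import numpy as np


def sample_per_row(probs, rng=None):
    """Draw one class index per row of a [batch x k] matrix of class probabilities.

    Hypothetical helper: assumes every row holds nonnegative weights with a
    positive sum (rows need not sum exactly to 1).
    """
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs, dtype=float)
    k = probs.shape[1]
    # Cumulative distribution along the class axis, shape [batch, k].
    cdf = np.cumsum(probs, axis=1)
    # One uniform draw per row in [0, row_total), shape [batch, 1].
    u = rng.random((probs.shape[0], 1)) * cdf[:, -1:]
    # Inverse CDF: the sampled index is the number of cumulative values <= u.
    idx = (cdf <= u).sum(axis=1)
    # Guard against floating-point round-off pushing the index to k.
    return np.minimum(idx, k - 1)


probs = np.array([[0.1, 0.2, 0.7], [0.1, 0.8, 0.1], [0.4, 0.5, 0.1]])
print(sample_per_row(probs))
```

Calling `sample_per_row` on the example array returns a length-3 integer array, such as the [2, 1, 0] in the question. The exact values vary between calls unless you pass a seeded `np.random.default_rng(seed)`. The whole batch is handled with array operations, so there is no Python loop over rows.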