I am trying to implement a loss function that combines hinge loss with KL divergence (paper, Lua implementation).

The current implementation of kl_div() returns either the sum or the mean of the values over the batch, but I need the individual per-sample values. Calling kl_div() on each pair separately is extremely slow (5-10x slower than calling it on the entire batch). Any advice on what I need to change or implement to get a version of kl_div() that returns a vector of values rather than the sum or mean?
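A minimal sketch, assuming this is PyTorch's `torch.nn.functional.kl_div` with inputs shaped `(batch, classes)`. Passing `reduction="none"` keeps the pointwise terms `target * (log(target) - input)` instead of reducing them, so summing over the non-batch dimensions gives one KL value per sample in a single vectorized call. Older releases exposed the same behaviour through the now-deprecated `reduce=False` argument. The helper name `per_sample_kl_div` is hypothetical.

```python
import torch
import torch.nn.functional as F


def per_sample_kl_div(log_input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: KL(target || exp(log_input)) for each sample in a batch.

    log_input: (batch, ...) log-probabilities, e.g. the output of log_softmax.
    target:    (batch, ...) probabilities with the same shape.
    """
    # reduction="none" returns the unreduced pointwise terms, same shape as the input.
    pointwise = F.kl_div(log_input, target, reduction="none")
    # Collapse every non-batch dimension so the result is a (batch,) vector.
    return pointwise.flatten(start_dim=1).sum(dim=1)


# Usage: one forward pass over the whole batch, no per-pair loop.
torch.manual_seed(0)
log_p = F.log_softmax(torch.randn(4, 5), dim=1)
q = F.softmax(torch.randn(4, 5), dim=1)
kl = per_sample_kl_div(log_p, q)  # shape (4,)
```

Because `kl` is an ordinary tensor with one entry per sample, it can be combined element-wise with other per-sample terms (such as a hinge term) before a final `.mean()` or `.sum()`, and autograd works through it as usual. Its mean matches `reduction="batchmean"`, which is a convenient sanity check.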

Thank you!