Implement Custom Loss

Can someone help me implement the cluster assignment hardening loss (page 3) from this paper: https://openreview.net/pdf?id=B1CEaMbR- in PyTorch?

I want to use this loss combined with an MSE loss in my network.

I am clueless about how to implement this.
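For reference, here is a rough sketch of what I think the loss looks like. It assumes the paper's cluster assignment hardening loss is the KL divergence KL(P || Q), where Q holds soft assignments from a Student's t kernel between embeddings and centroids and P is the sharpened target built from Q. I have not checked it against the paper's exact notation. The function names, `alpha`, the weight `gamma`, and the random tensors standing in for the network's inputs, embeddings, and reconstructions are all my own placeholders.

```python
import torch
import torch.nn.functional as F


def soft_assignments(z, centroids, alpha=1.0):
    """Soft cluster assignments Q of shape (N, K) via a Student's t kernel.

    z: (N, D) embeddings, centroids: (K, D) cluster centres.
    """
    # Squared Euclidean distance between every embedding and every centroid.
    dist_sq = (z.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(dim=2)
    q = (1.0 + dist_sq / alpha).pow(-(alpha + 1.0) / 2.0)
    # Normalise over clusters so each row sums to 1.
    return q / q.sum(dim=1, keepdim=True)


def target_distribution(q):
    """Sharpened target P: square Q, divide by soft cluster frequency, renormalise rows."""
    weight = q.pow(2) / q.sum(dim=0, keepdim=True)
    return weight / weight.sum(dim=1, keepdim=True)


def cluster_hardening_loss(z, centroids, alpha=1.0):
    """KL(P || Q), averaged over the batch. P is treated as a fixed target."""
    q = soft_assignments(z, centroids, alpha)
    p = target_distribution(q).detach()
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(q.log(), p, reduction="batchmean")


# Hypothetical usage: random tensors stand in for a real autoencoder's tensors.
x = torch.randn(16, 20)                              # input batch
x_recon = torch.randn(16, 20, requires_grad=True)    # decoder output placeholder
z = torch.randn(16, 5, requires_grad=True)           # encoder output placeholder
centroids = torch.nn.Parameter(torch.randn(3, 5))    # learnable cluster centres

gamma = 0.1  # assumed weighting between the two terms
total_loss = F.mse_loss(x_recon, x) + gamma * cluster_hardening_loss(z, centroids)
total_loss.backward()
```

In a real model `z` and `x_recon` would come from the encoder and decoder, so the gradients would flow back into the network weights and into `centroids`. Whether P should be detached on every step or only refreshed periodically is a design choice I am not sure about.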