RuntimeError: derivative for bincount is not implemented

Hello torch gurus,

I need an implementation of bincount that I can backpropagate through, and it looks like torch.bincount does not support that.

I am not sure how torch.bincount is implemented. Is there an efficient alternative implementation of bincount (or a workaround) that I can backprop through? My own vanilla implementation of bincount in torch is nowhere near as fast as torch.bincount for large arrays (length ~50k).
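
For reference, here is a minimal sketch of one possible workaround. It assumes the gradient is only needed with respect to per-element weights, since integer bin indices carry no gradient themselves. The helper name `bincount_diff` is hypothetical and not part of torch. It builds the histogram with `scatter_add`, which autograd can differentiate with respect to the source values.

```python
import torch

def bincount_diff(indices, weights, minlength=0):
    # Hypothetical helper (not a torch API): weighted bincount built on
    # scatter_add so that gradients flow back to `weights`.
    # Assumes `indices` is a 1-D int64 tensor of non-negative bin ids
    # and `weights` is a 1-D floating-point tensor of the same length.
    if indices.numel() > 0:
        n_bins = max(minlength, int(indices.max()) + 1)
    else:
        n_bins = minlength
    out = torch.zeros(n_bins, dtype=weights.dtype, device=weights.device)
    # Out-of-place scatter_add: out[indices[i]] += weights[i]
    return out.scatter_add(0, indices, weights)

idx = torch.tensor([0, 2, 2, 3])
w = torch.tensor([0.5, 1.0, 2.0, 4.0], requires_grad=True)
counts = bincount_diff(idx, w)
counts.sum().backward()  # w.grad now holds d(sum)/d(w)
```

This avoids a Python loop over elements, so it should scale to arrays of this size, though I have not benchmarked it against torch.bincount.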

Thanks