How to compute a joint histogram in PyTorch?

Hi jiawei_chen,
Did you find a solution? Any help would be appreciated.

Thank you
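For anyone who finds this thread later, here is one possible approach. It is a minimal sketch, not a confirmed answer from the original poster. It assumes you want to count how often each pair `(x[i], y[i])` falls into a 2-D grid of equal-width bins, for example the pixel intensities of two images of the same shape. The helper name `joint_histogram` and its `bins`/`x_range`/`y_range` parameters are hypothetical. The idea is to map each value to an integer bin index, combine the two indices into one flat index, and count the flat indices with `torch.bincount`.

```python
import torch

def joint_histogram(x: torch.Tensor, y: torch.Tensor, bins: int = 16,
                    x_range=None, y_range=None) -> torch.Tensor:
    """Hypothetical helper: (bins, bins) counts of the pairs (x[i], y[i]).

    x and y are flattened and must have the same number of elements.
    Values outside the given ranges are clamped into the edge bins.
    """
    x = x.flatten().float()
    y = y.flatten().float()
    if x.numel() != y.numel():
        raise ValueError("x and y must have the same number of elements")

    # Default to each tensor's own min/max when no range is given.
    x_min, x_max = x_range if x_range is not None else (x.min().item(), x.max().item())
    y_min, y_max = y_range if y_range is not None else (y.min().item(), y.max().item())

    def to_bin(v: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
        # Map values linearly onto [0, bins) and truncate to integer indices.
        # The small floor on the width avoids dividing by zero for constant input.
        scale = bins / max(hi - lo, 1e-12)
        idx = ((v - lo) * scale).floor().long()
        # The maximum value lands exactly on `bins`; clamp it (and any
        # out-of-range values) into the valid range [0, bins - 1].
        return idx.clamp(0, bins - 1)

    ix = to_bin(x, x_min, x_max)
    iy = to_bin(y, y_min, y_max)

    # Combine the 2-D bin coordinates into one flat index (row-major)
    # and count occurrences; minlength keeps empty bins in the output.
    flat = ix * bins + iy
    counts = torch.bincount(flat, minlength=bins * bins)
    return counts.reshape(bins, bins)


# Usage: four pairs, one in each cell of a 2x2 grid.
x = torch.tensor([0.0, 0.0, 1.0, 1.0])
y = torch.tensor([0.0, 1.0, 0.0, 1.0])
h = joint_histogram(x, y, bins=2)
print(h)
```

Row `i`, column `j` of the result counts the pairs whose `x` falls in bin `i` and whose `y` falls in bin `j`. Dividing by `h.sum()` turns the counts into a joint probability table. Note that `torch.bincount` is not differentiable. If you need gradients (for example, for a mutual-information loss), a common alternative is soft binning with a kernel such as a Gaussian, which this sketch does not cover.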