How to implement a loss that averages the coordinates of a heatmap weighted by its values

I want to implement a loss that extracts the location of the maximum of a heatmap. Since argmax is not differentiable, an alternative is to compute the average of the heatmap's coordinates, weighted by the heatmap values.

For example, [[0, 1], [3, 2]] has a maximum value of 3 at coordinate (1, 0), which can be approximately calculated as follows: (0*(0,0) + 1*(0,1) + 3*(1,0) + 2*(1,1)) / (0 + 1 + 3 + 2) = (5/6, 1/2), or about (0.83, 0.5).
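
The arithmetic for this 2x2 example can be checked with a few lines of plain Python. This is only a sketch: the variable names are made up for illustration, and coordinates are assumed to be (row, column) indices starting at 0.

```python
# Worked example for the 2x2 heatmap, using no libraries.
heatmap = [[0.0, 1.0],
           [3.0, 2.0]]

# Sum of all values, used to normalize the weights: 0 + 1 + 3 + 2 = 6.
total = sum(sum(row) for row in heatmap)

# Value-weighted average of the row index and of the column index.
exp_row = sum(r * v for r, row in enumerate(heatmap) for v in row) / total
exp_col = sum(c * v for row in heatmap for c, v in enumerate(row)) / total

print(exp_row, exp_col)
```

The result is roughly (0.83, 0.5) rather than exactly (1, 0), because every nonzero entry pulls the average toward its own position.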

The loss takes a heatmap of shape BxNxWxHxD as input and outputs coordinates of shape BxNx3.
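
One possible way to get from BxNxWxHxD to BxNx3 in PyTorch is sketched below. The function names `weighted_coords` and `coord_loss`, the `eps` term, and the choice of mean squared error against target coordinates are all assumptions made for illustration, not something stated in the question. The sketch also assumes the heatmap values are non-negative, so that dividing by their sum gives valid weights.

```python
import torch
import torch.nn.functional as F


def weighted_coords(heatmap: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Value-weighted average coordinates of a (B, N, W, H, D) heatmap.

    Returns a (B, N, 3) tensor of (w, h, d) positions in index units.
    Assumes non-negative heatmap values; `eps` (hypothetical) avoids
    division by zero for an all-zero heatmap.
    """
    B, N, W, H, D = heatmap.shape

    # Normalize each (W, H, D) volume so its weights sum to (almost) 1.
    total = heatmap.sum(dim=(2, 3, 4), keepdim=True)
    weights = heatmap / (total + eps)

    # Index vectors along each spatial axis, on the heatmap's device/dtype.
    kwargs = dict(dtype=heatmap.dtype, device=heatmap.device)
    w_idx = torch.arange(W, **kwargs)
    h_idx = torch.arange(H, **kwargs)
    d_idx = torch.arange(D, **kwargs)

    # Sum out the other two axes, then take the weighted mean of the index.
    exp_w = (weights.sum(dim=(3, 4)) * w_idx).sum(dim=-1)  # (B, N)
    exp_h = (weights.sum(dim=(2, 4)) * h_idx).sum(dim=-1)  # (B, N)
    exp_d = (weights.sum(dim=(2, 3)) * d_idx).sum(dim=-1)  # (B, N)

    return torch.stack([exp_w, exp_h, exp_d], dim=-1)  # (B, N, 3)


def coord_loss(heatmap: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Assumed loss: MSE between predicted and target (B, N, 3) coordinates."""
    return F.mse_loss(weighted_coords(heatmap), target)


# Usage with random data; gradients flow back to the heatmap.
heatmap = torch.rand(2, 4, 8, 8, 8, requires_grad=True)
target = torch.rand(2, 4, 3) * 7  # coordinates inside the 8x8x8 volume
loss = coord_loss(heatmap, target)
loss.backward()
```

If the heatmap can contain negative values, a common alternative is to replace the plain normalization with a softmax over the flattened WxHxD volume (often called soft-argmax), which also lets a scaling factor sharpen the weights toward the true maximum.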

I am not familiar with PyTorch and do not know how to implement such a loss. I would appreciate it if someone could help me.