Softargmax in PyTorch (in 2D)

Softargmax is used in quite a few places. The paper Deep Spatial Autoencoders for Visuomotor Learning probably introduced it.

Is there any function/layer in PyTorch that performs it, or any custom implementation?

I implemented it based on the following idea:

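# consider input [N,C,H,W] (batch, channels, height, width)
import torch
N, C, H, W = input.shape
temp = input.view(N, C, -1)
# softmax over the flattened H*W positions (torch.softmax takes the tensor as its first argument)
weights = torch.softmax(temp, dim=-1)
flat_idx = torch.arange(H * W, dtype=weights.dtype, device=weights.device)
# expected flattened index per map, shape [N, C]; flat_idx broadcasts over N and C
semi_indices = (weights * flat_idx).sum(dim=-1)
# row-major layout: flat index = row * W + col, so divide by W (not H)
indices_x = torch.div(semi_indices, W, rounding_mode="floor")  # row index
indices_y = semi_indices % W  # column index
# maths could be off a bit but basically somewhat like this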

But I really didn't like this implementation. Please suggest something better (maybe using Softmax2d, but I guess it would be similar). Is there any other way to do this?
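
For comparison, here is a minimal sketch of one possible alternative to the flattened-index approach above: take the softmax over all H*W positions, then compute the expected row and column coordinates separately instead of decoding a single expected flat index. The function name soft_argmax_2d and the temperature parameter are made up for illustration and are not part of PyTorch; this is only a sketch under those assumptions, not a reference implementation.

```python
import torch


def soft_argmax_2d(heatmap: torch.Tensor, temperature: float = 1.0):
    """Hypothetical helper: expected (x, y) pixel coordinates per map.

    heatmap has shape [N, C, H, W]; both outputs have shape [N, C].
    temperature is an assumed knob: smaller values sharpen the softmax.
    """
    n, c, h, w = heatmap.shape
    # Softmax over all H*W positions so that each [H, W] map sums to 1.
    probs = torch.softmax(heatmap.reshape(n, c, -1) / temperature, dim=-1)
    probs = probs.reshape(n, c, h, w)
    # Coordinate vectors in pixel units, matching the input dtype and device.
    ys = torch.arange(h, dtype=probs.dtype, device=probs.device)
    xs = torch.arange(w, dtype=probs.dtype, device=probs.device)
    # Marginalise over the other axis, then take the expected coordinate.
    exp_y = (probs.sum(dim=3) * ys).sum(dim=-1)  # expected row, [N, C]
    exp_x = (probs.sum(dim=2) * xs).sum(dim=-1)  # expected column, [N, C]
    return exp_x, exp_y


# Usage: a random batch of 2 images with 3 feature maps of size 4x7.
x, y = soft_argmax_2d(torch.randn(2, 3, 4, 7))
```

Every operation here is differentiable, so gradients can flow back into the heatmap. A perfectly uniform map should return the centre of the map, and a map with one dominant peak should return roughly the coordinates of that peak.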
