Sum over indices with value 1

Hi,

I have a tensor of shape (2, 2, 3), for example:
a = torch.tensor([[[2, 0, 2],
                   [1, 0, 0]],
                  [[1, 0, 1],
                   [0, 1, 0]]])

I want to find, along the last dimension, the indices where the value is 1, raise 2 to the power of each of those indices, and then sum the results over the last dimension. So the final result should be:

tensor([[[ ],
         [2^0]],
        [[2^0 + 2^2],
         [2^1]]])
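Put as a formula (the symbols are introduced here only for illustration: a_{ijk} is an entry of a, K = 3 is the size of the last dimension, and [·] is 1 when the condition inside holds and 0 otherwise), each output entry would be:

```latex
r_{ij} = \sum_{k=0}^{K-1} [\, a_{ijk} = 1 \,] \cdot 2^{k}
```

For instance, the vector [1, 0, 1] gives 2^0 + 2^2 = 5, while [2, 0, 2] contains no 1s, so its sum is empty and equals 0. The whole result is therefore tensor([[0, 1], [5, 2]]).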

My actual tensor is much bigger than this example, so I don’t want to use a “for” loop; I need a vectorized solution, e.g. with broadcasting.
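One loop-free way to do this (a sketch, not the only option) is to build the boolean mask a == 1, multiply it by the vector of powers 2 ** torch.arange(K), which broadcasts against the last dimension, and sum over that dimension. The function name powers_of_two_sum is hypothetical.

```python
import torch

def powers_of_two_sum(a: torch.Tensor) -> torch.Tensor:
    # Powers 2^0, 2^1, ..., 2^(K-1) for the K positions of the last dimension
    powers = 2 ** torch.arange(a.shape[-1], device=a.device)
    # (a == 1) has shape (..., K) and powers has shape (K,), so powers is
    # broadcast across all leading dimensions; True/False act as 1/0 here
    weighted = (a == 1) * powers
    # Sum along the last dimension; vectors without any 1 give an empty sum of 0
    return weighted.sum(dim=-1)

a = torch.tensor([[[2, 0, 2],
                   [1, 0, 0]],
                  [[1, 0, 1],
                   [0, 1, 0]]])
print(powers_of_two_sum(a))
# tensor([[0, 1],
#         [5, 2]])
```

In effect this reads each 0/1 mask along the last dimension as a binary number with index 0 as the least significant bit. Since the result is int64, this assumes the last dimension has at most 63 positions; larger sizes would overflow.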

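I was thinking of something like torch.pow(2, (a == 1).nonzero()).sum(), but it doesn’t give what I want: nonzero() returns the full coordinates of every matching entry across all dimensions, and .sum() then collapses everything into a single number. I need a way to apply (a == 1).nonzero() only along the last dimension. Any suggestions? Thanks.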
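If you would rather start from (a == 1).nonzero() as in the attempt above, one possible sketch is to use the last column of the returned coordinates as the exponent and the leading columns to pick which output cell each term is added to, via Tensor.index_put_ with accumulate=True. The function name nonzero_powers_sum is hypothetical.

```python
import torch

def nonzero_powers_sum(a: torch.Tensor) -> torch.Tensor:
    # Each row of idx is the full coordinate (i, j, ..., k) of one entry equal to 1
    idx = (a == 1).nonzero()
    # Last column is the position k within the last dimension -> term 2^k
    terms = 2 ** idx[:, -1]
    # One output cell per vector along the last dimension, starting at 0
    out = torch.zeros(a.shape[:-1], dtype=terms.dtype, device=a.device)
    # Leading columns select the output cell; accumulate=True adds up repeats
    out.index_put_(tuple(idx[:, :-1].T), terms, accumulate=True)
    return out

a = torch.tensor([[[2, 0, 2],
                   [1, 0, 0]],
                  [[1, 0, 1],
                   [0, 1, 0]]])
print(nonzero_powers_sum(a))
# tensor([[0, 1],
#         [5, 2]])
```

This only touches the entries equal to 1, but on a GPU nonzero() has a data-dependent output size and forces a host-device synchronization, so the mask-and-broadcast approach may be faster for dense inputs.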