I’m trying to figure out how to use torch.max() as the condition for torch.where(). My goal is to zero out the non-max values of a 4D tensor across the channels of each sample. For example:
# given:
>>> import torch
>>> a = torch.arange(3*2*2).view(1, 3, 2, 2)
>>> a
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]]])
# wanted:
tensor([[[[ 0,  0],
          [ 0,  0]],

         [[ 0,  0],
          [ 0,  0]],

         [[ 8,  9],
          [10, 11]]]])
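A minimal sketch of one way to get the wanted output, assuming "non-max" means every entry that is not the largest along the channel dimension (dim=1) at its own spatial position. The idea is that torch.where needs a boolean tensor as its condition, so the channel-wise max is compared back against a (keepdim=True keeps the reduced dimension so the comparison broadcasts):

```python
import torch

# Example input from the question: 1 sample, 3 channels, 2x2 spatial.
a = torch.arange(3 * 2 * 2).view(1, 3, 2, 2)

# Maximum over the channel dimension (dim=1) for every sample and position.
# keepdim=True gives shape (1, 1, 2, 2), which broadcasts against a.
channel_max = a.amax(dim=1, keepdim=True)

# Boolean mask: True where an entry equals the channel max at its position.
# If several channels tie for the max, all of them stay True.
mask = a == channel_max

# Keep the max entries and replace everything else with zero.
result = torch.where(mask, a, torch.zeros_like(a))
print(result)
```

With this mask, ties keep every tied channel; if exactly one channel per position is required, the mask has to be built from argmax indices instead.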
I’m trying to use some form of torch.where(a.max(dim=1), a, torch.zeros_like(a)), but I can’t quite figure out how to get it to work.
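Part of the difficulty is that torch.max(a, dim=1) does not return a boolean tensor: it returns a named tuple of (values, indices), so it can't be passed directly as the condition. The sketch below, assuming exactly one channel should survive per spatial position, uses the indices from torch.max itself and compares them against a broadcast grid of channel ids to build the mask (the names channel_ids and mask are just illustrative):

```python
import torch

a = torch.arange(3 * 2 * 2).view(1, 3, 2, 2)

# With dim given, torch.max returns a named tuple (values, indices).
# keepdim=True keeps the channel dim, so both have shape (1, 1, 2, 2).
values, indices = torch.max(a, dim=1, keepdim=True)

# Channel ids 0..C-1 shaped (1, C, 1, 1) so they broadcast over N, H, W.
channel_ids = torch.arange(a.size(1), device=a.device).view(1, -1, 1, 1)

# True only in the channel whose index matches the argmax at each position,
# so exactly one channel per position is kept even when values tie.
mask = channel_ids == indices

result = torch.where(mask, a, torch.zeros_like(a))
print(result)
```

Compared with testing a == values, this variant always keeps a single channel per position; which channel wins on a tie is whatever index torch.max reports.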
Does anyone know a nice approach to do this?