Mask tensor without losing gradients

Hello, I know this might be trivial, but I've been struggling with it for quite some time. I have a tensor of shape [batch, num_nodes, num_nodes] and I would like to mask it so that along the second dimension the biggest number is replaced with 1 and the rest with 0. So far I've come up with this: mask = (mask == mask.max(dim=2, keepdim=True)[0]). It works, but I'm losing the gradients. Any help is welcome.
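
For reference, here is a minimal sketch that reproduces the problem and shows one possible workaround. The tensor sizes, the name `scores`, and the random `weights` are made up for illustration. The workaround is a straight-through estimator built on softmax, which is an assumption on my part and not something from the post above: the forward pass returns the hard 0/1 mask, while the backward pass borrows the gradient of the softmax.

```python
import torch

torch.manual_seed(0)
batch, num_nodes = 2, 4  # hypothetical sizes, for illustration only
scores = torch.randn(batch, num_nodes, num_nodes, requires_grad=True)

# Hard mask as in the post: the comparison returns a bool tensor,
# which is not part of the autograd graph.
hard = scores == scores.max(dim=2, keepdim=True)[0]
print(hard.dtype, hard.requires_grad)  # torch.bool False

# Straight-through sketch (an assumption, not the only option):
# the forward value matches the hard mask up to floating-point rounding,
# and the gradient flows through the softmax term.
soft = torch.softmax(scores, dim=2)
hard_f = hard.to(soft.dtype)
mask = hard_f - soft.detach() + soft

# Dummy downstream loss so there is something to backpropagate.
weights = torch.randn(batch, num_nodes, num_nodes)
loss = (mask * weights).sum()
loss.backward()
print(mask.requires_grad, scores.grad.shape)
```

Note that the gradient obtained this way is the softmax gradient, not a gradient of the argmax itself (which is zero almost everywhere), so whether it is a useful training signal depends on the model.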