How to extract the largest value of a tensor while keeping its original position

I have a tensor:

a = torch.Tensor([[0.0, 0.9, 0.8, 0.0, 0.2, 0.0],
                  [0.2, 0.0, 0.0, 0.8, 0.7, 0.0],
                  [0.6, 0.0, 0.0, 0.0, 0.0, 0.7]])

and I want to keep only the maximum value along dim=1 (the largest value in each row) at its original position, with every other entry set to zero, like this:

res = torch.Tensor([[0.0, 0.9, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.8, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.7]])

How can I implement this?

Thank you!

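This should do the trick: take the index of the largest value in each row with argmax, turn those indices into a one-hot mask with scatter, and multiply the mask with the original tensor.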

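import torch

a = torch.tensor([[0.0, 0.9, 0.8, 0.0, 0.2, 0.0],
                  [0.2, 0.0, 0.0, 0.8, 0.7, 0.0],
                  [0.6, 0.0, 0.0, 0.0, 0.0, 0.7]])
# Column index of the largest value in each row, reshaped to (3, 1) for scatter
b = torch.argmax(a, dim=1).unsqueeze(1)
# One-hot mask: 1.0 at each row's argmax position, 0.0 everywhere else
one_hot = torch.zeros_like(a).scatter(1, b, 1.0)
# Keep only the row maxima at their original positions
req_op = a * one_hot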
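A slightly shorter variant, offered as an alternative sketch rather than as the method above: torch.max with a dim argument returns both the values and their indices, so the row maxima can be scattered straight into a zero tensor without building a mask first.

```python
import torch

a = torch.tensor([[0.0, 0.9, 0.8, 0.0, 0.2, 0.0],
                  [0.2, 0.0, 0.0, 0.8, 0.7, 0.0],
                  [0.6, 0.0, 0.0, 0.0, 0.0, 0.7]])

# Max value and its column index for every row; keepdim=True keeps shape (3, 1)
values, idx = a.max(dim=1, keepdim=True)

# Start from zeros and write each row's max back at its original column
res = torch.zeros_like(a).scatter(1, idx, values)
print(res)
```

Like argmax, max returns a single index per row, so if a row contains its maximum more than once, only one of those positions is kept. If every tied position should survive, comparing against the row maximum is one option, e.g. torch.where(a == a.max(dim=1, keepdim=True).values, a, torch.zeros_like(a)).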