Masked argmax in PyTorch?

Thanks. I need to convert m to a tensor, then.
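A minimal sketch of one common approach to a masked argmax. It assumes `m` is a boolean-like mask (a nested list, NumPy array, or tensor) with the same shape as a floating-point score tensor `x`, where True marks entries that are allowed to win. The names `x` and `masked_argmax` are hypothetical and used only for illustration. `torch.as_tensor` handles the conversion of `m`, and `masked_fill` pushes disallowed entries to `-inf` before calling `argmax`:

```python
import torch

def masked_argmax(x: torch.Tensor, m, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper. Assumes x is a floating-point tensor and m is a
    # boolean-like mask with the same shape (True = position may be chosen).
    # as_tensor converts a list/NumPy array to a bool tensor on x's device.
    mask = torch.as_tensor(m, dtype=torch.bool, device=x.device)
    # Replace disallowed entries with -inf so they can never be the maximum.
    filled = x.masked_fill(~mask, float("-inf"))
    return filled.argmax(dim=dim)

x = torch.tensor([[0.1, 0.9, 0.3],
                  [0.8, 0.2, 0.5]])
m = [[True, False, True],
     [False, True, True]]
print(masked_argmax(x, m))  # tensor([2, 2])
```

Filling with `-inf` requires a floating-point `x`; for integer scores you would need a different sentinel (for example the dtype's minimum value). If every entry in a row is masked out, the returned index for that row is not meaningful, so check `mask.any(dim=dim)` first if that case can occur.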