I had to implement something similar. My approach was the following, where mask is a tensor of 1s and 0s with the same shape as vec, and a 0 marks an entry to be masked out:
import torch

def masked_softmax(vec, mask, dim=1):
    # Zero out masked entries so they don't dominate the max.
    masked_vec = vec * mask.float()
    # Subtract the per-row max for numerical stability.
    max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
    exps = torch.exp(masked_vec - max_vec)
    # Zero the exponentials at masked positions.
    masked_exps = exps * mask.float()
    masked_sums = masked_exps.sum(dim, keepdim=True)
    # If an entire row is masked, its sum is 0; bump it to 1
    # so we return zeros instead of NaNs from a 0/0 division.
    zeros = (masked_sums == 0)
    masked_sums += zeros.float()
    return masked_exps / masked_sums
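
For example (the values here are just illustrative), masking out the last two entries of a row:

vec = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(masked_softmax(vec, mask))
# tensor([[0.2689, 0.7311, 0.0000, 0.0000]])

The unmasked entries sum to 1 as in a regular softmax, the masked positions come out as exactly 0, and a fully masked row yields all zeros rather than NaNs.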