Backpropagation error through scatter_

I use .scatter_() to build a one-hot mask and apply a margin to the target class, like this:

cosine = F.linear(F.normalize(input), F.normalize(self.weight))
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m # phi = cos(theta + m)

one_hot = torch.zeros(cosine.size(), device='cuda')
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 

With this code, a vanishing-gradient problem occurs at this line:

output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 

How can I do this correctly?
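
For reference, here is a minimal sketch of one way the same computation could be written. It rests on assumptions, since the post does not show the actual gradient values: the one-hot mask built with scatter_ is a constant that does not require grad, so it is probably not what breaks backpropagation; a more likely suspect is torch.sqrt, whose derivative is unbounded when 1 - cosine^2 reaches 0 and which returns NaN if rounding pushes that value slightly below 0. The function name margin_logits, the eps value, and the toy shapes are hypothetical and chosen only for illustration.

```python
import math

import torch
import torch.nn.functional as F


def margin_logits(features, weight, label, cos_m, sin_m, eps=1e-7):
    """Apply cos(theta + m) to the target class only (hypothetical helper)."""
    cosine = F.linear(F.normalize(features), F.normalize(weight))
    # Clamp before the sqrt: its gradient is unbounded at 0 and its value is
    # NaN for negative inputs. eps is an assumed value, not a tuned one.
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=eps))
    phi = cosine * cos_m - sine * sin_m  # phi = cos(theta + m)

    # Constant mask on the same device and dtype as cosine; no grad needed.
    one_hot = torch.zeros_like(cosine)
    one_hot.scatter_(1, label.view(-1, 1).long(), 1.0)

    # Take phi at the target class and cosine everywhere else.
    return torch.where(one_hot.bool(), phi, cosine)


# Toy usage with made-up sizes: 4 samples, 8 features, 5 classes.
torch.manual_seed(0)
features = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(5, 8, requires_grad=True)
label = torch.tensor([0, 2, 4, 1])
m = 0.5  # assumed margin

out = margin_logits(features, weight, label, math.cos(m), math.sin(m))
loss = F.cross_entropy(out, label)
loss.backward()
```

After backward(), checking torch.isfinite(weight.grad).all() is a quick way to tell whether the gradients are NaN/inf rather than merely small, which helps separate the sqrt issue from a true vanishing-gradient problem.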