How to do weight normalization on every forward pass

Hi all,
I am trying to implement the CosFace loss (Large Margin Cosine Loss). For this loss, I need to normalize the weight of the last Linear layer in the module on every forward pass.

I have checked several related posts, like the one here.
But in my case, the normalization operation needs to be differentiable, so that gradients can backpropagate through it.
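For what it's worth, F.normalize on a plain tensor is differentiable on its own; a quick self-contained check (the variable names are just for illustration):

import torch
import torch.nn.functional as F

w = torch.randn(4, 3, requires_grad=True)
F.normalize(w, p=2, dim=1).sum().backward()  # backprops through the normalization
print(w.grad is not None)  # True

My worry is what happens once the normalized result is re-wrapped as a new Parameter.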

So I might use some code like:

# replace the stored weight with an L2-normalized copy on each forward pass
fc.weight = nn.Parameter(F.normalize(fc.weight, p=2, dim=1))

I just want to know: would this cast operation (casting a Tensor back to a Parameter) affect the gradient or not?
Or is there another way to deal with it?
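
For reference, here is a minimal sketch of the alternative I am considering: keeping the Parameter untouched and normalizing it inside forward() with F.normalize, so the normalization stays in the autograd graph. The class name NormalizedLinear and the dimensions are my own, not from the paper:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NormalizedLinear(nn.Module):
    # Linear layer whose weight rows are L2-normalized on every forward
    # pass; the normalization happens inside forward(), so gradients
    # flow through it and the stored Parameter is never replaced.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # Normalizing both features and weights makes the output the
        # cosine similarity between each input and each class vector.
        x = F.normalize(x, p=2, dim=1)
        w = F.normalize(self.weight, p=2, dim=1)
        return F.linear(x, w)

The margin m and scale s from the Large Margin Cosine Loss would then be applied to these cosine values before the softmax, but that part is separate from the normalization question.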

Thanks!

Hey!
Did you manage to solve this problem?

I am in a similar situation.