How to implement batch L2 normalization with PyTorch

Hey guys, I'm new to PyTorch. Is there a PyTorch API that applies L2 normalization to a tensor? In TensorFlow, the corresponding API is tf.nn.l2_normalize.
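L2 normalization divides each vector by its Euclidean (L2) norm, so every vector ends up with length 1. The sketch below does this by hand in PyTorch and assumes the vectors lie along the last dimension. `l2_normalize` is a hypothetical helper name. The `eps` floor mirrors the default of `torch.nn.functional.normalize`, which keeps an all-zero vector at zero instead of producing NaN.

```python
import torch

def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    # Hypothetical helper: compute the L2 norm of every slice along `dim`,
    # keep that dim so the result broadcasts back against x, and clamp the
    # norm from below by eps so all-zero slices don't divide by zero.
    norm = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True).clamp_min(eps)
    return x / norm

x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
out = l2_normalize(x)
# First row becomes [0.6, 0.8] (3/5, 4/5); the all-zero row stays zero.
print(out)
```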


I think I just got the answer.

```python
import torch
import torch.nn.functional as f

a = torch.randn(2, 3)
norm_a = f.normalize(a, dim=0, p=2)
```

Here p=2 selects the L2 norm, and dim=0 normalizes along dimension 0. As a result, each column of a ends up with unit L2 norm.
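To see what the dim argument controls, this sketch normalizes the same 2×3 tensor along each dimension and then checks which slices have unit norm. The seed and tensor shape are arbitrary and only there for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(2, 3)

col_normed = F.normalize(a, p=2, dim=0)  # each of the 3 columns gets unit L2 norm
row_normed = F.normalize(a, p=2, dim=1)  # each of the 2 rows gets unit L2 norm

col_norms = torch.linalg.vector_norm(col_normed, dim=0)  # 3 values, each ~1.0
row_norms = torch.linalg.vector_norm(row_normed, dim=1)  # 2 values, each ~1.0
print(col_norms, row_norms)
```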


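Thanks for the code. For a batch where each row is one sample, I think it should be dim=-1 instead of 0, so that each row (sample) is normalized rather than each column: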
norm_a = f.normalize(a,dim=-1,p=2)
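The following is a minimal batch sketch. It assumes a batch laid out as (batch_size, features) and uses 8 samples of 128 features, which are illustrative sizes only. With dim=-1, every sample is scaled to unit L2 norm independently of the others.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch = torch.randn(8, 128)  # assumed layout: 8 samples, 128 features each

# Normalize each sample (row) independently along the feature dimension.
normed = F.normalize(batch, p=2, dim=-1)

# For a 2-D tensor, dim=-1 and dim=1 name the same (last) dimension.
same = torch.equal(normed, F.normalize(batch, p=2, dim=1))

row_norms = torch.linalg.vector_norm(normed, dim=-1)
print(row_norms.shape)  # torch.Size([8])
```

Using dim=-1 rather than dim=1 has one advantage: the code still targets the feature dimension if extra leading dimensions are added later, because -1 always refers to the last one.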