How to do a scaled L2 distance calculation in PyTorch?

Hello,

I want to do a traceable scaled L2 distance calculation in PyTorch, something like this:
A_ik = G_k (X_i - Y_k)^2, where X has shape (B, N, D), Y has shape (B, K, D), and G has shape (B, K), so the output A has shape (B, N, K). Does anyone know how I can do that? Thanks in advance.

Have you looked at torch.einsum?
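
Here is a rough sketch of two ways this could be written. It assumes the squared difference is summed over the last dimension D; the question does not say so explicitly, but the (B, N, K) output shape suggests it. The function names are made up for illustration. The first version uses plain broadcasting, and the second expands the square as |x|^2 - 2 x.y + |y|^2 so that torch.einsum computes the cross term without building a (B, N, K, D) intermediate.

```python
import torch


def scaled_l2_broadcast(X, Y, G):
    # Hypothetical helper. X: (B, N, D), Y: (B, K, D), G: (B, K).
    # Assumption: the squared difference is summed over D.
    diff = X.unsqueeze(2) - Y.unsqueeze(1)           # (B, N, K, D)
    sq_dist = diff.pow(2).sum(dim=-1)                # (B, N, K)
    return G.unsqueeze(1) * sq_dist                  # (B, N, K)


def scaled_l2_einsum(X, Y, G):
    # Same quantity via |x|^2 - 2 x.y + |y|^2, avoiding the 4-D intermediate.
    x2 = X.pow(2).sum(dim=-1)                        # (B, N)
    y2 = Y.pow(2).sum(dim=-1)                        # (B, K)
    xy = torch.einsum('bnd,bkd->bnk', X, Y)          # (B, N, K)
    sq_dist = x2.unsqueeze(2) - 2 * xy + y2.unsqueeze(1)
    return G.unsqueeze(1) * sq_dist                  # (B, N, K)
```

Both versions use only ordinary tensor operations, so they should be differentiable and should work with torch.jit.trace, though that is worth checking on your own setup. The expanded form can give tiny negative values from floating-point cancellation when X_i and Y_k are nearly equal; clamping sq_dist at zero before scaling is one way to handle that if it matters.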