Implementing weighted least squares with torch.nn


This is my first post on this forum. I'm trying to implement a simple linear regression task, but it is slightly different from the usual OLS problem.

I need help implementing the first term efficiently. Is there a module that helps compute a weighted least-squares loss?
My current attempt is:

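output = W.matmul(X)  # k * D matrix
# Trace of the D * D matrix (output-y)^T Sigma (output-y), i.e. the weighted sum of squared residuals
l = torch.diagonal((output-y).T.matmul(Sigma).matmul(output-y)).sum()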
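One possible way to avoid the D * D intermediate, sketched under assumptions: `output` and `y` are k * D (from the comment above), `Sigma` is a k * k weight matrix, and `X` is p * D. The sizes, the random data, and the function names `wls_loss_trace`, `wls_loss`, and `wls_loss_einsum` are hypothetical. Only the diagonal of (output-y)^T Sigma (output-y) is needed, and entry d of that diagonal is r[:, d]^T Sigma r[:, d], so the whole sum equals the sum of the elementwise product r * (Sigma @ r).

```python
import torch


def wls_loss_trace(output: torch.Tensor, y: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    # Formulation from the post: materializes the full D x D matrix r^T Sigma r,
    # then sums its diagonal (its trace).
    r = output - y
    return torch.diagonal(r.T.matmul(Sigma).matmul(r)).sum()


def wls_loss(output: torch.Tensor, y: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    # Same value without the D x D intermediate:
    # sum_d r[:, d]^T Sigma r[:, d] = sum_{i,d} r[i, d] * (Sigma @ r)[i, d]
    r = output - y                      # k x D residuals
    return (r * Sigma.matmul(r)).sum()  # only k x D intermediates


def wls_loss_einsum(output: torch.Tensor, y: torch.Tensor, Sigma: torch.Tensor) -> torch.Tensor:
    # Equivalent einsum spelling of the same quadratic form.
    r = output - y
    return torch.einsum("id,ij,jd->", r, Sigma, r)


if __name__ == "__main__":
    torch.manual_seed(0)
    k, p, D = 3, 4, 1000                # hypothetical sizes: W is k x p, X is p x D
    W = torch.randn(k, p, requires_grad=True)
    X = torch.randn(p, D)
    y = torch.randn(k, D)
    A = torch.randn(k, k)
    Sigma = A @ A.T + torch.eye(k)      # hypothetical symmetric positive-definite weights

    output = W.matmul(X)                # k x D, as in the post
    l = wls_loss(output, y, Sigma)
    l.backward()                        # gradients flow back to W as usual
    print(l.item(), W.grad.shape)
```

The elementwise version costs roughly O(k^2 D) time and O(k D) memory, while the trace version also pays O(k D^2) for the D * D product, which dominates when D is large. With `Sigma = torch.eye(k)` all three functions reduce to `((output - y) ** 2).sum()`, the ordinary least-squares loss.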

Thank you in advance,