Hello everyone!
I’m currently working on a regression problem where my model output and my targets are of size 1x5, and I want to implement a weighted RMSE loss that gives the same result as the square root of PyTorch’s built-in nn.MSELoss when no weights are provided. I’m sharing it here in case it’s useful for anyone:
import torch
import torch.nn as nn


class WeightedRMSELoss(nn.Module):
    def __init__(self, weights=None):
        super().__init__()
        self.weights = weights  # Optional weights of shape (1, features)

    def forward(self, y_pred, y_true, weights=None):
        # Weights can be given when creating the loss or when calling it;
        # weights passed to forward() take precedence.
        if weights is None:
            weights = self.weights
        if weights is None:
            # No weights anywhere: default to uniform weights that sum to 1 per row.
            weights = torch.full_like(y_true, 1.0 / y_true.size(1))
        else:
            # Normalize the weights vector so the sum of all elements is equal to 1.
            weights = weights / torch.sum(weights)
        # Calculate squared errors
        squared_errors = (y_pred - y_true) ** 2
        # Apply weights; with uniform weights this is the plain per-row MSE.
        weighted_squared_errors = squared_errors * weights
        # Take the square root; dim=1 gives one RMSE value for each row.
        rmse_single = torch.sqrt(torch.sum(weighted_squared_errors, dim=1))
        # Reduce over the batch dimension.
        rmse = torch.mean(rmse_single)
        return rmse
I compared it to the built-in MSELoss by creating two random (1, n) vectors for a range of values of n and calculating my custom loss as WeightedRMSELoss()(y_pred, y_true) and PyTorch’s loss as torch.sqrt(nn.MSELoss()(y_pred, y_true)). Note that this function isn’t supposed to give an equal result when the inputs are matrices with more than one row: my loss function computes one RMSE value per row and then averages over the batch dimension, whereas torch.sqrt(nn.MSELoss()(x, y)) averages the squared errors over all elements and then applies the square root. The weights vector has to have a shape of (1, features) for input vectors of shape (batch size, features) in the forward method.
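In case it helps, here is a minimal sketch of that comparison. It assumes a hypothetical functional helper, `weighted_rmse`, that mirrors the forward pass of the class (it is not part of PyTorch), and it uses small fixed tensors instead of random ones so the numbers are easy to check by hand.

```python
import torch
import torch.nn as nn


def weighted_rmse(y_pred, y_true, weights=None):
    """Hypothetical helper for illustration: per-row weighted RMSE, averaged over the batch."""
    if weights is None:
        # Uniform weights: each feature contributes 1 / n_features.
        weights = torch.full_like(y_true, 1.0 / y_true.size(1))
    else:
        # Normalize so the weights sum to 1 (expects shape (1, features)).
        weights = weights / weights.sum()
    per_row = torch.sqrt(torch.sum(weights * (y_pred - y_true) ** 2, dim=1))
    return per_row.mean()


mse = nn.MSELoss()  # default reduction="mean" averages over all elements

# Single row of shape (1, 5): both computations should agree.
y_pred = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
y_true = torch.tensor([[1.5, 2.0, 2.0, 4.0, 7.0]])
single_custom = weighted_rmse(y_pred, y_true)
single_builtin = torch.sqrt(mse(y_pred, y_true))

# Batch of two rows: mean of per-row RMSEs vs. square root of the pooled MSE.
y_pred_b = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
y_true_b = torch.tensor([[1.0, 1.0], [3.0, 3.0]])
batch_custom = weighted_rmse(y_pred_b, y_true_b)     # (1 + 3) / 2 = 2.0
batch_builtin = torch.sqrt(mse(y_pred_b, y_true_b))  # sqrt((1 + 1 + 9 + 9) / 4) = sqrt(5)

print(single_custom.item(), single_builtin.item())
print(batch_custom.item(), batch_builtin.item())
```

In the batch case the per-row RMSEs are 1 and 3, so their mean is 2, while the pooled MSE is 5 and its square root is about 2.236. Because the square root is concave, averaging per-row RMSEs can never give a larger value than taking the square root of the pooled MSE.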
I hope this can help somebody! Any comments or feedback are appreciated.