Custom loss function error reporting

I want to report my regression error in a particular way (RMSE per attribute of each data sample). I can do this in two ways:

  1. Creating a custom loss function that computes this metric directly, so the value I report is simply the training loss itself (the `loss` tensor I call `backward()` on before `optimizer.step()`). My custom loss is:

```python
import torch
import torch.nn as nn


class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.mse = nn.MSELoss()

    def forward(self, output, target, attribute):
        # Normalize prediction and target by the per-sample attribute
        output_per_attribute = torch.div(output, attribute)
        target_per_attribute = torch.div(target, attribute)
        # RMSE of the normalized values
        loss = torch.sqrt(self.mse(output_per_attribute, target_per_attribute))
        return loss
```

  2. Using an MSE loss function for training and then adding lines that compute the desired metric separately:

```python
optimizer.zero_grad()
output = model(input)
loss = MSELoss(output, target)

with torch.no_grad():
    # Reuse the forward pass above instead of calling the model a second time
    output_per_attribute = torch.div(output, attributes)
    target_per_attribute = torch.div(target, attributes)
    # RMSE, to match the custom loss in the first approach
    desired_loss = torch.sqrt(MSELoss(output_per_attribute, target_per_attribute))
    print(desired_loss)

loss.backward()
optimizer.step()
```
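A self-contained sketch of the second pattern, assuming a toy linear model, random data, and strictly positive per-sample `attributes` (all hypothetical stand-ins, not the setup from this thread). The training step optimizes plain MSE, while the per-attribute RMSE is computed from the same forward pass inside `torch.no_grad()`, so the metric never enters the autograd graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: 64 samples, 3 features, 1 regression target.
inputs = torch.randn(64, 3)
targets = torch.randn(64, 1)
# Strictly positive per-sample attribute used for normalization (assumption).
attributes = torch.rand(64, 1) + 0.5

model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step(x, y, attr):
    """Optimize plain MSE; report per-attribute RMSE as a detached metric."""
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)

    with torch.no_grad():
        # Reuse `output`; no second forward pass is needed.
        metric = torch.sqrt(criterion(output / attr, y / attr))

    loss.backward()
    optimizer.step()
    return loss.item(), metric.item()


history = [train_step(inputs, targets, attributes) for _ in range(5)]
for step, (mse, rmse_attr) in enumerate(history):
    print(f"step {step}: train MSE={mse:.4f}  RMSE per attribute={rmse_attr:.4f}")
```

Reporting the metric this way keeps the optimization objective unchanged, so any difference in training speed cannot come from the reporting code.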

My question is: is there a preferable way to approach this? I tried both and noticed that training with the first method seemed slower, possibly because the loss values are smaller, so the gradients are smaller, leading to smaller steps? Thanks!
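A small check of the "smaller gradients" guess, assuming the first method really trains on the custom loss. That loss is not a constant rescaling of MSE: for L = sqrt(mean(((y_i - t_i) / a_i)**2)) over N elements, the gradient is dL/dy_i = (y_i - t_i) / (N * a_i**2 * L), while plain MSE gives 2 * (y_i - t_i) / N. The per-element ratio is therefore 1 / (2 * a_i**2 * L), which depends on each attribute and on the current loss value. The tensors below are hypothetical random values used only to compare autograd against these closed forms.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N = 8

# Hypothetical predictions, targets and strictly positive attributes.
output = torch.randn(N, 1, requires_grad=True)
target = torch.randn(N, 1)
attribute = torch.rand(N, 1) * 4 + 1.0  # values in [1, 5)

mse = nn.MSELoss()

# Gradient of the per-attribute RMSE loss (what the first method trains on).
custom = torch.sqrt(mse(output / attribute, target / attribute))
(grad_custom,) = torch.autograd.grad(custom, output)

# Gradient of plain MSE (what the second method trains on).
plain = mse(output, target)
(grad_plain,) = torch.autograd.grad(plain, output)

# Closed-form gradients for comparison.
diff = (output - target).detach()
expected_custom = diff / (N * attribute**2 * custom.detach())
expected_plain = 2 * diff / N

print("per-element ratio custom/plain:", (grad_custom / grad_plain).squeeze())
```

With attributes larger than 1 the custom gradients tend to be smaller, which fits the observation, but the ratio also grows as L shrinks, so the effective step size changes over training rather than by a fixed factor.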

If you’ve just wrapped the loss function using nn.MSELoss into your custom class, both approaches should work identically.
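To illustrate the condition in this reply: a class that only forwards to `nn.MSELoss` produces the same loss value and the same gradients as calling `nn.MSELoss` directly. `WrappedMSE` is a hypothetical name used for this sketch.

```python
import torch
import torch.nn as nn


class WrappedMSE(nn.Module):
    """Hypothetical custom class that only wraps nn.MSELoss."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, output, target):
        return self.mse(output, target)


torch.manual_seed(0)
output = torch.randn(5, 2, requires_grad=True)
target = torch.randn(5, 2)

loss_wrapped = WrappedMSE()(output, target)
(grad_wrapped,) = torch.autograd.grad(loss_wrapped, output)

loss_plain = nn.MSELoss()(output, target)
(grad_plain,) = torch.autograd.grad(loss_plain, output)

print(loss_wrapped.item(), loss_plain.item())
```

Once the class also divides by the attribute and takes a square root, as in the first approach, both the value and the gradients change, so the two training runs are no longer expected to match.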
Is the slower training reproducible?

Also, in your second script it looks like you are not initializing nn.MSELoss but are passing the output and target tensors to it directly. Is this a typo from pasting the code?

Could you check the shapes of the tensors passed to the loss (model output and target) in both approaches?
Make sure both shapes are equal; otherwise unwanted broadcasting might happen, which could slow down or stop your training.
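A minimal illustration of the broadcasting pitfall, assuming a model output of shape [batch, 1] and a target of shape [batch] (a common mismatch, chosen here only for illustration). `nn.MSELoss` broadcasts the two to [batch, batch], averages over every prediction/target pair, and emits a UserWarning about the size mismatch. Because it does not raise an error, the loss is simply wrong and training may fail to converge.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.MSELoss()

output = torch.randn(4, 1)  # e.g. model output of shape [batch, 1]
target = torch.randn(4)     # target accidentally of shape [batch]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Broadcasts to [4, 4]: every prediction is compared with every target
    broadcast_loss = criterion(output, target)

# Shapes now match ([4, 1] vs [4, 1]), so each prediction meets its own target
correct_loss = criterion(output, target.unsqueeze(1))

print("broadcast:", broadcast_loss.item(), "correct:", correct_loss.item())
print("warnings raised:", len(caught))
```

Printing `output.shape` and `target.shape` once before the first training step is usually enough to catch this.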

Thanks! Yes, I had initialized my loss function accordingly; I just skipped that line to save time when pasting the code. I can double-check the shapes of the inputs and targets, but I believe there was no issue there either. If, from a higher-level view, both approaches should work the same, then that's all I needed to confirm.