I think @ptrblck's solution is the best one, because it is the simplest.
Just for fun, here are two other ways to do it:
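Create a function (this is my favorite choice):

```python
import torch

# yhat: model predictions, y: targets (tensors of the same shape)
def RMSELoss(yhat, y):
    return torch.sqrt(torch.mean((yhat - y) ** 2))

criterion = RMSELoss
loss = criterion(yhat, y)
```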
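As a quick sanity check, here is a minimal sketch that runs the function version on small made-up tensors (the values of `yhat` and `y` are illustrative assumptions). It confirms that the result equals `sqrt(MSE)` and that gradients flow back through it.

```python
import torch
import torch.nn.functional as F

# Same function-based RMSE as above, repeated so this snippet runs on its own.
def RMSELoss(yhat, y):
    return torch.sqrt(torch.mean((yhat - y) ** 2))

# Hypothetical predictions and targets, for illustration only.
yhat = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([1.0, 2.0, 5.0])

loss = RMSELoss(yhat, y)  # sqrt((0^2 + 0^2 + (-2)^2) / 3) = sqrt(4/3)
loss.backward()           # autograd differentiates through sqrt, mean and pow

print(loss.item())
print(yhat.grad)
# Same value as taking the square root of PyTorch's built-in MSE (mean reduction).
print(torch.allclose(loss, torch.sqrt(F.mse_loss(yhat, y))))
```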
Create an `nn.Module` class (just-for-fun choice :-)):

```python
import torch
import torch.nn as nn

class RMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()  # mean squared error with mean reduction

    def forward(self, yhat, y):
        return torch.sqrt(self.mse(yhat, y))

criterion = RMSELoss()
loss = criterion(yhat, y)
```
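One caveat that applies to both versions: the derivative of `torch.sqrt` is infinite at 0, so when the predictions match the targets exactly the backward pass can produce NaN gradients. A common workaround (not part of the original suggestion) is to add a small epsilon inside the square root. The sketch below is a variant of the class above that assumes `eps=1e-8`; that value is an arbitrary choice, not a PyTorch default.

```python
import torch
import torch.nn as nn

class RMSELoss(nn.Module):
    """RMSE variant with a small epsilon inside the sqrt (eps=1e-8 is an assumed value)."""
    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps  # keeps sqrt away from 0, where its derivative is infinite

    def forward(self, yhat, y):
        return torch.sqrt(self.mse(yhat, y) + self.eps)

criterion = RMSELoss()
yhat = torch.zeros(4, requires_grad=True)
y = torch.zeros(4)          # perfect predictions, so the MSE is exactly 0
loss = criterion(yhat, y)
loss.backward()             # gradients stay finite thanks to eps

print(loss.item(), yhat.grad)
```

The epsilon slightly inflates the reported loss (the sketch returns about `sqrt(eps)` instead of 0 for perfect predictions), which is negligible unless the errors themselves are tiny.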