Loss for each sample in batch

I want to extract the loss value for each sample in a training/testing batch. What is the most efficient way to get this?

Should I use the method below, which calls the loss function twice?

loss_fn = nn.MSELoss()
loss_all = loss_fn(input, target)
loss_each = torch.mean(loss_fn(input, target).detach(), 1)
loss_all.backward()           # this loss is used for backward
other_computation(loss_each)  # not used for backward

Thanks very much!


loss_all will be a scalar, since nn.MSELoss computes the mean over all elements by default. The second loss_fn call in your snippet therefore also returns a scalar, so torch.mean cannot recover the per-sample losses from it (and passing dim 1 to torch.mean will raise an error on a 0-dim tensor).
If you want the loss value for each sample, you could specify reduction='none' for the criterion and compute loss_all via torch.mean:

import torch
import torch.nn as nn

loss_fn = nn.MSELoss(reduction='none')
input = torch.randn(10, 1, requires_grad=True)
target = torch.randn(10, 1)
loss_each = loss_fn(input, target)      # shape [10, 1]: one loss value per element
loss_all = torch.mean(loss_each)        # scalar mean, used for backward
loss_all.backward()
other_computation(loss_each.detach())   # detached, so not part of the graph
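
If your samples have more than one feature dimension, loss_each will hold one value per element rather than per sample. Here is a minimal sketch (the [10, 5] shapes are assumptions for illustration) that averages over the non-batch dimension to get one loss per sample:

import torch
import torch.nn as nn

loss_fn = nn.MSELoss(reduction='none')
input = torch.randn(10, 5, requires_grad=True)   # batch of 10 samples, 5 features each
target = torch.randn(10, 5)

loss_elementwise = loss_fn(input, target)        # shape [10, 5]: one value per element
loss_per_sample = loss_elementwise.mean(dim=1)   # shape [10]: one value per sample
loss_all = loss_per_sample.mean()                # scalar, used for backward
loss_all.backward()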