Continuing the discussion from Custom loss functions: In this solution, @ptrblck demonstrated a custom loss function. However, the function my_loss returns a torch tensor. How, then, does calling .backward() on this tensor work in PyTorch?
In my case, my loss function looks like this:
import torch

def my_loss(out, tar):
    # Multiply element-wise, then sum to a single scalar (0-dim) tensor
    loss = torch.sum(out * tar)
    return loss
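A minimal sketch of calling this loss and backpropagating through it. It assumes `out` stands in for a model output (so it is created with `requires_grad=True`) and `tar` is a fixed target; both tensors are hypothetical random data for illustration. Because the loss is `sum(out * tar)`, the gradient with respect to `out` should simply equal `tar`.

```python
import torch

def my_loss(out, tar):
    # Multiply element-wise, then sum to a single scalar (0-dim) tensor
    loss = torch.sum(out * tar)
    return loss

# Hypothetical inputs: `out` plays the role of a model output and tracks gradients
out = torch.randn(4, requires_grad=True)
tar = torch.randn(4)

loss = my_loss(out, tar)
loss.backward()  # populates out.grad with d(loss)/d(out)

# For loss = sum(out * tar), d(loss)/d(out) is exactly tar
print(torch.allclose(out.grad, tar))  # True
```

Only a scalar loss like this one can call `.backward()` with no arguments; a non-scalar result would need an explicit `gradient` argument.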
Now, this returned loss does not have any parameters or a method named backward(), so how can I call .backward() on the torch tensor returned by loss = my_loss(out, tar)?
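For reference, a short sketch (with hypothetical random tensors) that inspects the returned value: `backward()` is a method defined on `torch.Tensor` itself, so any tensor has it, and when at least one input requires gradients autograd records a `grad_fn` on the result that `backward()` walks through. If no input requires gradients, there is no recorded graph and calling `backward()` raises a `RuntimeError`.

```python
import torch

# Hypothetical tensors for illustration
out = torch.randn(3, requires_grad=True)
tar = torch.randn(3)

loss = torch.sum(out * tar)

# backward() is defined on torch.Tensor, so every tensor exposes it
print(isinstance(loss, torch.Tensor))     # True
print(hasattr(torch.Tensor, "backward"))  # True

# Since `out` requires grad, autograd recorded how `loss` was computed
print(loss.grad_fn is not None)           # True

# With no input requiring grad, no graph is recorded
plain = torch.sum(torch.randn(3) * tar)
print(plain.grad_fn)                      # None
try:
    plain.backward()
except RuntimeError as err:
    print("RuntimeError:", err)
```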