Hi, I am trying to write a custom loss function by creating an nn.Module subclass with a forward method.
Is it possible to reuse PyTorch's built-in reduction mode handling ('none', 'mean', 'sum') in my loss?
Sure, I can do:
if reduction == 'mean':
    return torch.mean(my_loss_value)
elif ...
But I would rather just call something like:
return torch_reduction_handling_function(my_loss_value, reduction=reduction)
However, I don’t know where to find such a reduction handling function, or whether it is publicly exposed at all.
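I have not found a public function for this, so for now my fallback is a small helper of my own. This is a minimal sketch, assuming only the three standard modes ('none', 'mean', 'sum'). `apply_reduction` and `MyLoss` are hypothetical names I made up, not PyTorch APIs, and the absolute-error loss is just a placeholder:

```python
import torch
import torch.nn as nn


def apply_reduction(loss: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Hypothetical helper mirroring the reduction modes of built-in losses."""
    if reduction == "none":
        return loss  # keep the per-element losses
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"{reduction!r} is not a valid value for reduction")


class MyLoss(nn.Module):
    """Hypothetical custom loss: element-wise absolute error (placeholder)."""

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        my_loss_value = (input - target).abs()  # unreduced, per-element loss
        return apply_reduction(my_loss_value, reduction=self.reduction)


pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])
print(MyLoss("sum")(pred, target))   # per-element losses 0.5, 0.0, 1.0 summed
print(MyLoss("none")(pred, target))  # unreduced tensor of shape (3,)
```

It works, but I would still prefer to call whatever PyTorch uses internally, if that is exposed.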
Thanks for your help!