Custom loss function cannot return multiple values?

Hi, if I try to return multiple values from a custom loss function, it throws an error saying that the loss function object is not iterable.
Demo example

import torch
import torch.nn as nn
from torch.autograd.function import Function

class CenterLoss(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2, size_average=True):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply  # custom autograd Function, defined elsewhere (not shown)
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # To check the dim of centers and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        loss /= (batch_size if self.size_average else 1)
        return loss,feat

b,c = CenterLoss()

TypeError Traceback (most recent call last)
in ()
----> 1 b,c = CenterLoss()

TypeError: 'CenterLoss' object is not iterable

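You have to create an instance of the CenterLoss class first and then call it with the inputs. b,c = CenterLoss() only constructs the module and then tries to unpack the module object itself, which is not iterable; the (loss, feat) tuple is returned only when forward runs, i.e. when the instance is called.
The code below should work.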

loss = CenterLoss()
b, c = loss(label, feat)
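To make the construct-then-call pattern concrete, here is a self-contained sketch. CenterlossFunc is not shown in the thread, so this version swaps in a plain-tensor computation of the center loss (0.5 times the sum of squared distances from each feature to its class center, divided by the batch size when size_average is set); that formulation is an assumption standing in for the original Function, not a copy of it.

```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Sketch only: plain tensor ops replace the (not shown) CenterlossFunc."""

    def __init__(self, num_classes=10, feat_dim=2, size_average=True):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        if feat.size(1) != self.feat_dim:
            raise ValueError(
                "Center's dim: {0} should be equal to input feature's dim: {1}".format(
                    self.feat_dim, feat.size(1)))
        # Assumed formulation: pick each sample's class center, then sum squared distances
        centers_batch = self.centers.index_select(0, label.long())
        loss = 0.5 * (feat - centers_batch).pow(2).sum()
        loss = loss / (batch_size if self.size_average else 1)
        # forward may return any Python object, including a tuple
        return loss, feat


criterion = CenterLoss(num_classes=10, feat_dim=2)  # step 1: construct the module
label = torch.tensor([0, 3, 7])
feat = torch.randn(3, 2, requires_grad=True)
loss, flat_feat = criterion(label, feat)            # step 2: call it, then unpack the tuple
loss.backward()
print(loss.item(), flat_feat.shape)

unpack_error = None
try:
    b, c = CenterLoss()  # tries to unpack the module object itself, not forward's result
except TypeError as err:
    unpack_error = err
    print("TypeError:", err)
```

Gradients flow to both the features and the learnable centers, so the tuple return does not interfere with backward().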

This doesn't work...

Can you share the error you get?

I can run the code below without problems. Since I don't have access to the rest of your code, I just return the label and the features unchanged.

import torch
import torch.nn as nn
from torch.autograd.function import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2, size_average=True):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

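    def forward(self, label, feat):
        batch_size = feat.size(0)
        # Flatten each sample into a 1-D feature vector: shape (batch_size, -1)
        feat = feat.view(batch_size, -1)
        # Placeholder: return the inputs instead of computing a real loss
        return label, feat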


loss = CenterLoss()
label = torch.ones(2)
feat = torch.ones(2)
print(loss(label, feat))

The output is

(tensor([1., 1.]), tensor([[1.],
        [1.]]))

I found a workaround: instead of returning the second variable, I accessed it as an attribute via class_name.feat.
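A minimal sketch of that attribute-based workaround, assuming forward stores the tensor on the module before returning only the loss. The class name FlattenAndCache and its placeholder loss are hypothetical, chosen just for illustration.

```python
import torch
import torch.nn as nn


class FlattenAndCache(nn.Module):
    """Hypothetical module: returns only the loss, caches the features on self."""

    def __init__(self):
        super().__init__()
        self.feat = None  # filled in by forward()

    def forward(self, label, feat):
        feat = feat.view(feat.size(0), -1)
        self.feat = feat  # side channel: read it as module.feat after the call
        # Placeholder loss for illustration only: mean squared difference to the label
        return (feat - label.view(-1, 1)).pow(2).mean()


criterion = FlattenAndCache()
loss = criterion(torch.zeros(4), torch.ones(4, 3))
print(loss.item(), criterion.feat.shape)
```

Returning a tuple is usually the cleaner choice: the cached attribute only ever holds the most recent batch, and if that tensor is part of an autograd graph, the graph stays alive until the next call overwrites it.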

Any suggestions on how to fix the loss becoming NaN after a few epochs?

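Have a look at the new anomaly detection mode for the autograd engine (torch.autograd.detect_anomaly, or torch.autograd.set_detect_anomaly to switch it on globally).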
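A minimal sketch of how anomaly detection points at the operation whose backward first produces a NaN. The square root of a negative number is a contrived stand-in for whatever goes wrong in the real loss; the helper name locate_nan_backward is hypothetical.

```python
import torch


def locate_nan_backward():
    """Return the anomaly-mode error message, or None if no NaN was caught."""
    x = torch.tensor([-1.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x).sum()  # sqrt(-1) -> nan in the forward pass
        try:
            y.backward()  # anomaly mode checks every backward output for NaN
        except RuntimeError as err:
            return str(err)  # names the offending backward function, e.g. SqrtBackward0
    return None


msg = locate_nan_backward()
print(msg)
```

In a training loop you would wrap the forward and backward of one iteration the same way (or call torch.autograd.set_detect_anomaly(True) once). Anomaly mode also prints the forward-pass traceback of the failing operation, but it slows training noticeably, so enable it only while debugging.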
