Hi, if I try to return multiple values from a custom loss function, it throws an error saying that the loss object is not iterable.
Demo Example
import torch
import torch.nn as nn
from torch.autograd.function import Function
class CenterLoss(nn.Module):
    """Center-loss module that returns both the loss and the flattened features.

    Holds a learnable ``centers`` parameter of shape ``(num_classes, feat_dim)``
    and delegates the actual loss computation to ``CenterlossFunc.apply``.

    NOTE(review): ``CenterlossFunc`` is not defined in this snippet. It must be
    in scope when the module is instantiated; presumably it is a custom
    ``torch.autograd.Function`` -- confirm against its definition.
    """

    def __init__(self, num_classes=10, feat_dim=2, size_average=True):
        """Create the randomly initialised centers and store the settings.

        Args:
            num_classes: Number of rows of the ``centers`` parameter.
            feat_dim: Number of columns of ``centers``; the flattened input
                features passed to ``forward`` must have this width.
            size_average: If true, ``forward`` divides the loss by the batch
                size.
        """
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        """Compute the loss for a batch.

        Args:
            label: Passed through unchanged to ``CenterlossFunc.apply``
                (presumably per-sample class indices -- confirm).
            feat: Feature tensor; it is flattened to ``(batch_size, -1)``
                where ``batch_size`` is ``feat.size(0)``.

        Returns:
            A ``(loss, feat)`` tuple, where ``feat`` is the flattened view of
            the input features.

        Raises:
            ValueError: If the flattened feature width differs from
                ``feat_dim``.
        """
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # Check that the feature width matches the width of the centers.
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        if self.size_average:
            # Out-of-place division: do not mutate the tensor returned by the
            # autograd function, and skip the no-op divide when not averaging.
            loss = loss / batch_size
        return loss, feat
# NOTE: CenterLoss() only constructs the module object; forward(label, feat)
# is never called here, so there is no (loss, feat) tuple to unpack. Unpacking
# the module instance itself is what raises the TypeError shown below. To get
# both values, create the instance first and then call it with (label, feat).
b,c = CenterLoss()
**TypeError Traceback (most recent call last)
in ()
----> 1 b,c = CenterLoss()
TypeError: 'CenterLoss' object is not iterable**