HI,
I am defining an customer operator, with backward:
class LabelSmoothSoftmaxCEFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, label, lb_smooth):
num_classes = logits.size(1)
scores = torch.softmax(logits, dim=1)
logs = torch.log(scores)
lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
label = (1. - lb_smooth) * lb_one_hot + lb_smooth / num_classes
ctx.lb_smooth = lb_smooth
ctx.num_classes = num_classes
ctx.scores = scores
ctx.lb_one_hot = lb_one_hot
loss = -torch.sum(logs * label, dim=1)
return loss.mean()
@staticmethod
def backward(ctx, grad_output):
print(ctx.saved_variables)
lb_smooth = ctx.lb_smooth
num_classes = ctx.num_classes
scores = ctx.scores
lb_one_hot = ctx.lb_one_hot
loss = scores - lb_smooth / num_classes
# loss[lb_one_hot == 1] = -loss[lb_one_hot]
loss = torch.where(lb_one_hot == 1, -loss, loss)
return loss, None, None
class LabelSmoothSoftmaxCEV2(nn.Module):
def __init__(self, lb_smooth):
super(LabelSmoothSoftmaxCEV2, self).__init__()
self.lb_smooth = lb_smooth
def forward(self, logits, label):
return LabelSmoothSoftmaxCEFunction.apply(logits, label, self.lb_smooth)
When I tried to print the loss, I got:
loss = criteria(logits, label)
print(loss.item())
()
8.125082015991211
()
8.200040817260742
Why is there a ()
when I print the value, how could I make it work as other loss?