Custom Loss assert not torch.is_tensor(other) error

Hi

I’m trying to write a custom loss function.

However, there seems to be something I’m not getting right. For now, this is just a dummy example.

import torch
import torch.nn as nn
from torch.autograd import Variable

class BCELossReg(torch.nn._functions.thnn.BCELoss):

    def __init__(self, ratio, size_averaged=True):
        super(BCELossReg, self).__init__(size_averaged)
        self.ratio = ratio

    def forward(self, input, target, n):
        result = super(BCELossReg, self).forward(input, target)

        result = self.ratio * result + (1-ratio) * result * n

        return result


# init
model = nn.Sequential(
    nn.Linear(5,3),
    nn.ReLU(),
    nn.Linear(3,1),
    nn.Sigmoid()
)
r = Variable(torch.Tensor([0.9]), requires_grad=False)
n = Variable(torch.Tensor([2]), requires_grad=False)
loss_f = BCELossReg(r)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
optim.zero_grad()

# model forward pass
data = Variable(torch.Tensor([1,2,3,4,5]).view(1,-1))
prediction = model(data)
target = Variable(torch.Tensor([1]))

# loss and backward
loss = loss_f(prediction, target, n)
loss.backward()
optim.step()

When I execute this, I get the following error:

Traceback (most recent call last):
  File "BCELossReg.py", line 38, in <module>
    loss = loss_f(prediction, target, n)
  File "BCELossReg.py", line 14, in forward
    result = self.ratio * result + (1-ratio) * result * n
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torch/autograd/variable.py", line 761, in __mul__
    return self.mul(other)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torch/autograd/variable.py", line 305, in mul
    assert not torch.is_tensor(other)
AssertionError


Somewhere you’re multiplying a Variable by a tensor when you need to multiply a Variable by a Variable. It’s not clear from your snippet where that is.
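For reference, a minimal sketch of that failure mode on the old (pre-0.4) Variable API that your traceback points to; the names here are made up for illustration:

import torch
from torch.autograd import Variable

v = Variable(torch.Tensor([2.0]))   # autograd Variable
t = torch.Tensor([3.0])             # plain tensor, not wrapped

ok = v * Variable(t)   # Variable * Variable works fine
bad = v * t            # Variable * plain tensor trips "assert not torch.is_tensor(other)"

The second multiplication raises the same AssertionError you're seeing, so check that every operand in your forward is a Variable.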

This looks like a bug:

result = self.ratio * result + (1-ratio) * result * n

Instead of (1-ratio) you should have (1-self.ratio); the bare name ratio isn’t defined inside forward, only self.ratio is.
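In case it helps, a minimal, untested sketch of the forward with that change applied, assuming the base class forward returns something you can multiply with your Variable arguments:

# inside BCELossReg
def forward(self, input, target, n):
    result = super(BCELossReg, self).forward(input, target)
    # use self.ratio consistently; the bare name ratio is not in scope here
    result = self.ratio * result + (1 - self.ratio) * result * n
    return result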