Binary classifier: RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

I’m trying to build a basic binary classifier that predicts whether my player is on the right or the left side in the game Pong, but I keep getting an error. At each step I get one 1x42x42 image and the side (right = 1, left = 2). The code:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

net = Net(42 * 42, 100, 2)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_net = torch.optim.Adam(net.parameters(), 0.001)
net.train()

while True:
    state = get_game_img()
    state = torch.from_numpy(state)

    # right = 1, left = 2
    current_side = get_player_side()
    target = torch.LongTensor(current_side)
    x = Variable(state.view(-1, 42 * 42))
    y = Variable(target)
    optimizer_net.zero_grad()
    y_pred = net(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer_net.step()

The error I get:

  File "train.py", line 109, in train
    loss = criterion(y_pred, y)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward
    self.weight, self.size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss
    return f(input, target)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57

Did you find a workaround for this issue?

I’m getting the same error. I worked around it by keeping num_classes greater than 1. In my case I’m training a network to detect just the digit 5 in MNIST: I still output 10 classes, but my labels are 0 or 1, where 1 is for digit 5 and 0 is for all other digits.
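A minimal sketch of the relabeling described above, assuming a recent PyTorch release (where torch.tensor is available and Variable is no longer needed). The digits batch, the random inputs, and the single linear layer are hypothetical stand-ins, not the poster's actual code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch of MNIST digit labels and flattened 28x28 images.
digits = torch.tensor([5, 0, 3, 5, 9])
x = torch.rand(5, 28 * 28)

# Binary target: 1 where the digit is 5, 0 for every other digit.
y = (digits == 5).long()

# Stand-in model with two outputs, one logit per class.
model = nn.Linear(28 * 28, 2)
criterion = nn.CrossEntropyLoss()

loss = criterion(model(x), y)  # input shape (5, 2), target shape (5,)
loss.backward()
```

Ten outputs also work with these labels, since 0 and 1 are both below 10, but two outputs are enough for a 0/1 target.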

Values of target should be in the range [0, n_classes) for CrossEntropyLoss, so with two classes the labels 1 and 2 trip the assertion. Mapping your target to {0, 1} does work, yes.
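As a rough illustration of that mapping for the Pong case, assuming a recent PyTorch release: the batch of random frames and the raw_sides list below are made up for the example, and a single linear layer stands in for the Net class from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 flattened 42x42 frames and their sides,
# encoded as in the question (right = 1, left = 2).
x = torch.rand(4, 42 * 42)
raw_sides = [1, 2, 2, 1]

# Shift the labels into [0, n_classes): right -> 0, left -> 1.
# torch.tensor(...) stores the given values, whereas torch.LongTensor(n)
# with an int n allocates an uninitialized tensor of length n.
y = torch.tensor([s - 1 for s in raw_sides], dtype=torch.long)

model = nn.Linear(42 * 42, 2)      # two logits, one per class
criterion = nn.CrossEntropyLoss()

loss = criterion(model(x), y)      # input shape (4, 2), target shape (4,)
loss.backward()
```

Note that the target holds one class index per sample, so its length must match the batch dimension of the model output.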

I changed the loss function to BCELoss.
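One way the BCELoss variant could look, sketched under assumptions: the poster did not share their code, so the layer sizes are borrowed from the original question, the data is random, and a recent PyTorch release is assumed. BCELoss expects probabilities in [0, 1] and float targets with the same shape as the output, so the network ends in a sigmoid and emits a single value per sample.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 flattened 42x42 frames with float targets
# (0.0 = right, 1.0 = left), shaped (4, 1) to match the model output.
x = torch.rand(4, 42 * 42)
y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

# One output unit squashed to a probability by the sigmoid.
model = nn.Sequential(
    nn.Linear(42 * 42, 100),
    nn.ReLU(),
    nn.Linear(100, 1),
    nn.Sigmoid(),
)
criterion = nn.BCELoss()

prob_left = model(x)               # shape (4, 1), values in (0, 1)
loss = criterion(prob_left, y)
loss.backward()
```

nn.BCEWithLogitsLoss is an alternative that takes raw logits and applies the sigmoid internally; in that case the final nn.Sigmoid layer would be dropped.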
