Basic question - need help with F.nll_loss

I have the following model and training method. In simple terms this what nll_loss does right? it takes set of output from neurons and it picks out the winning neuron (index of the neuron).

class AND(nn.Module):
    def __init__(self, input_count):
        super(AND, self).__init__()
        assert(input_count >= 2)
        self.linear = nn.Linear(input_count, 2)
       
    def forward(self, x):
        return F.log_softmax(self.linear(x))
        

from itertools import product
def and_gate(*inputs):
    for i in inputs:
        if i == 0:
            return 0
    else:
        return 1
  

def truth_table_for_and(size=2):
    table = []
    for sample in product([0, 1], repeat=size):
        table.append([list(sample), and_gate(*sample)])
    return table   

‚Äč

Example output of truth table:
pprint(truth_table_for_and(2))

[[[0, 0], 0], [[0, 1], 0], [[1, 0], 0], [[1, 1], 1]]

and2 = AND(2)
truth_table2 = truth_table_for_and(2)
def train(epochs, print_every=10):
    optimizer = optim.SGD(and2.parameters(), lr=0.1, momentum=0.01)
    
    for epoch in range(epochs):
        for i, o in truth_table2:
            i , o = Variable(torch.Tensor([i])), Variable(torch.Tensor([o]))
            o_ = and2(i)

            optimizer.zero_grad()
            print(o_.size(), o.size())
            loss = F.nll_loss(o_, o)
            loss.backward()
            optmizer.step()
            
        if epoch % print_every == 0:
            print('loss: {}'.format(loss.data[0]))   

train(100)

torch.Size([1, 2]) torch.Size([1]) <---- output of print(o_.size(), o.size())

In the F.nll_loss function, i get the following error. I have pondered for a while, but could not get it to work.

TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)

Got it. Used FloatTensor inplace of LongTensor for target values.