I am trying to build a logistic regression model using the PyTorch framework.
I want a binary logistic regression with "m" parameters.
Hence I thought the dimension of the network would be [m, 2] (2 for the two binary classes).
But it is giving 2*m parameters (weights).
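The post doesn't show the model definition, so the sketch below assumes the "[m, 2]" network is a single `nn.Linear(m, 2)` layer. It counts that layer's parameters (a 2*m weight matrix plus 2 biases) against `nn.Linear(m, 1)` (m weights plus 1 bias). It also checks numerically that the class-1 softmax probability from two output columns equals a sigmoid of the difference of the two logits, which is why a single output column is enough for two classes. The value of `m` and the random inputs are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = 5  # hypothetical number of input features

two_out = nn.Linear(m, 2)  # one score per class: weight [2, m] + bias [2]
one_out = nn.Linear(m, 1)  # single logit: weight [1, m] + bias [1]

def n_params(layer):
    # Count every trainable element (weights and biases)
    return sum(p.numel() for p in layer.parameters())

print(n_params(two_out))  # 2*m + 2 = 12
print(n_params(one_out))  # m + 1 = 6

# With two outputs, softmax(z)[:, 1] = 1 / (1 + exp(z0 - z1)) = sigmoid(z1 - z0),
# so the probability depends only on the difference of the two logits.
x = torch.randn(3, m)
z = two_out(x)
p_softmax = torch.softmax(z, dim=1)[:, 1]
p_sigmoid = torch.sigmoid(z[:, 1] - z[:, 0])
print(torch.allclose(p_softmax, p_sigmoid, atol=1e-6))  # True
```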
When I changed the dimension to [m, 1], I got the following error while running the training code:
```
RuntimeError                              Traceback (most recent call last)
in <module>()
     26 outputs = model(data_X)
     27 labels = labels.squeeze_()
---> 28 loss = criterion(outputs, labels)
     29 loss.backward()
     30 optimizer.step()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
    914     def forward(self, input, target):
    915         return F.cross_entropy(input, target, weight=self.weight,
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
    917
    918

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   1993     if size_average is not None or reduce is not None:
   1994         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1995     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   1996
   1997

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1822                          .format(input.size(0), target.size(0)))
   1823     if dim == 2:
-> 1824         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1825     elif dim == 4:
   1826         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /pytorch/aten/src/THNN/generic/ClassNLLCriterion.c:94
```
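Judging from the traceback, `criterion` is `nn.CrossEntropyLoss` (its `forward` calls `F.cross_entropy`). That loss takes the number of output columns as the number of classes, so with an [m, 1] layer the only valid target is 0, and any label equal to 1 trips `cur_target < n_classes`. The sketch below reproduces this with made-up data, then shows one common single-logit alternative, `nn.BCEWithLogitsLoss`, which expects float targets with the same shape as the output. The sizes, data, learning rate, and epoch count are hypothetical, and newer PyTorch versions report the failure with a different exception and message, so the sketch catches either type.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, n = 5, 8                          # hypothetical feature count and batch size
data_X = torch.randn(n, m)           # made-up inputs
labels = torch.randint(0, 2, (n,))   # 0/1 class labels as integers
labels[0] = 1                        # guarantee at least one label equal to 1

model = nn.Linear(m, 1)              # single output column: m weights + 1 bias
outputs = model(data_X)              # shape [n, 1]

# CrossEntropyLoss reads outputs.size(1) == 1 as "one class", so label 1 is out of range.
try:
    nn.CrossEntropyLoss()(outputs, labels)
    ce_failed = False
except (IndexError, RuntimeError):   # exception type/message varies across PyTorch versions
    ce_failed = True
print(ce_failed)  # True

# Binary alternative: one logit per sample, float targets shaped like the outputs.
criterion = nn.BCEWithLogitsLoss()   # applies the sigmoid internally
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(data_X)                                  # [n, 1]
    loss = criterion(outputs, labels.float().unsqueeze(1))   # target [n, 1], float
    loss.backward()
    optimizer.step()

with torch.no_grad():
    probs = torch.sigmoid(model(data_X)).squeeze(1)  # P(class 1) per sample
preds = (probs > 0.5).long()                         # hard 0/1 predictions
print(loss.item(), preds)
```

Keeping a single output column matches the m-weight (plus one bias) logistic regression described above; the sigmoid inside `BCEWithLogitsLoss` plays the role that the two-class softmax plays with `CrossEntropyLoss`.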