So, should it be `outputs = torch.argmax(outputs, 1, keepdim=True)`?
# Replace each row of `outputs` with the index of its largest entry along
# dimension 1. keepdim=True keeps that dimension as size 1 instead of dropping it.
# NOTE(review): presumably `outputs` is (batch, num_classes) scores/logits,
# giving (batch, 1) predicted class indices — confirm against the caller.
outputs = torch.argmax(outputs, dim=1, keepdim=True)