torch.argmax is not differentiable: it returns integer indices, so no gradient can flow back through it and the computation graph is broken at that point.
Also, nn.BCELoss expects probabilities as the model output and is used for binary or multi-label classification use cases, so you should apply a sigmoid to the model output in this case.
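
Here is a minimal sketch of both points. The linear model, input shapes, and targets are made up for illustration and are not taken from your code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 4 samples, 3 features, one binary target per sample.
model = nn.Linear(3, 1)
x = torch.randn(4, 3)
target = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

logits = model(x)

# argmax returns integer indices, which are detached from the graph.
pred_idx = torch.argmax(logits, dim=1)
argmax_requires_grad = pred_idx.requires_grad

# sigmoid maps the logits to probabilities in (0, 1), as nn.BCELoss expects.
probs = torch.sigmoid(logits)
loss = nn.BCELoss()(probs, target)
loss.backward()

# Gradients reach the model parameters because sigmoid is differentiable.
grad_is_populated = model.weight.grad is not None

print(argmax_requires_grad, grad_is_populated)
```

If you would rather pass raw logits to the loss, nn.BCEWithLogitsLoss applies the sigmoid internally, so you would drop the explicit torch.sigmoid call in that case.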