Training with threshold in PyTorch

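Hi @learner47, I don’t think x can be differentiable here: data_array >= x is a comparison that returns a boolean mask, and autograd does not propagate gradients through comparisons, so x never receives a gradient. You could try modifying the way you compute the threshold. Even then, the gradient flowing back through the thresholded values is going to be 1 where data_array >= x and 0 elsewhere. This discussion might help clarify things: "How to make the parameter of torch.nn.Threshold learnable?"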
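Here is a small sketch that makes both points concrete. The tensor values and the names hard, soft, x2 and temperature are made up for illustration. The first half uses a hard threshold: data_array gets a 0/1 gradient and x gets none at all. The second half shows one common workaround, which I'm assuming here rather than taking from the linked thread: replace the step with a sigmoid so the mask depends smoothly on the threshold, which lets it receive a gradient.

```python
import torch

# Hypothetical input and threshold, for illustration only.
data_array = torch.tensor([0.2, 0.5, 0.8, 1.1], requires_grad=True)
x = torch.tensor(0.6, requires_grad=True)

# Hard threshold: keep values >= x, zero out the rest.
# (data_array >= x) is a bool tensor with no grad_fn, so x never
# enters the autograd graph through it.
mask = data_array >= x
hard = torch.where(mask, data_array, torch.zeros_like(data_array))
hard.sum().backward()

print(data_array.grad)  # 1 where data_array >= x, 0 elsewhere
print(x.grad)           # None: no gradient reached x

# Assumed workaround: a sigmoid "soft step" around the threshold.
# Smaller temperature values make it closer to the hard step
# (but also make the gradient w.r.t. the threshold sharper).
temperature = 0.05  # hypothetical value
x2 = torch.tensor(0.6, requires_grad=True)

# data_array is detached so this example only tracks gradients for x2
# and does not add to data_array.grad from the first example.
inputs = data_array.detach()
soft_mask = torch.sigmoid((inputs - x2) / temperature)
soft = soft_mask * inputs
soft.sum().backward()

print(x2.grad)  # a nonzero gradient: the threshold is now learnable
```

The trade-off is that the soft version only approximates the hard threshold. You'd typically use it during training and could switch back to the hard comparison at inference time.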

Best!