Loss function for differentiable programming (targets computed from y_pred)

My process looks similar to this: https://fluxml.ai/assets/2019-03-05-dp-vs-rl/trebuchet-flow.png

I have some features and some labels. From these labels, I can invoke a custom function to check whether some target criteria are met. I don’t care about the labels per se, only that the target criteria are met. Now, similar to the Trebuchet example, my ANN should infer the labels from the features.

My idea is to write a custom loss function that takes the predicted labels (the network outputs) and checks the target criteria against them. The sum of the target criteria differences is then the loss for the neural network.

In principle, the code works, but the model is not learning (the loss is exactly the same in every epoch). Likely I’m breaking the graph by converting the labels to numpy, which I have to do in order to calculate the targets.

Does anyone know how to handle this “numpy interruption” (without writing a custom backward)?

Loss:

import torch
import torch.nn as nn

def solve_targets(x, y, ...):
    return some_numpy_calculations(x, y, ...)

class FLoss(nn.Module):
    def __init__(self, t, g, xsc, dvc):
        super(FLoss, self).__init__()
        self.targets = t
        self.device = dvc
        ...

    def forward(self, x, y):
        # The numpy round-trip detaches x and y from the autograd graph.
        loss_np = solve_targets(x.cpu().data.numpy(), y.cpu().data.numpy())
        # res_values is a brand-new leaf tensor with no connection to the model.
        res_values = torch.tensor(loss_np, requires_grad=True, device=self.device)
        loss = torch.clamp_min_(res_values - self.targets, 0.).sum()
        print("loss", loss.item())
        return loss
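
For what it’s worth, a minimal standalone check (a toy example, not my real data) shows what I mean by breaking the graph: a tensor rebuilt from numpy has no grad_fn, so backward never reaches the original input.

    import torch

    x = torch.randn(4, requires_grad=True)
    res = torch.tensor(x.cpu().data.numpy() * 2.0, requires_grad=True)
    print(res.grad_fn)   # None: res is a new leaf, not connected to x
    res.sum().backward()
    print(x.grad)        # None: no gradient ever flows back to x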

Hi,

You should never use .data as a general rule :slight_smile:

In this case, it breaks the computational graph and prevents gradients from being computed.
Also, I’m afraid autograd won’t work if you perform your ops on numpy arrays. You will either have to convert your code to use pytorch functions (which should be the simplest) or write a custom backward.
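
For the first option, here is a minimal sketch of what a fully-torch loss could look like (with a made-up criterion standing in for solve_targets, since I don’t know what it actually computes):

    import torch
    import torch.nn as nn

    class FLossTorch(nn.Module):
        # Same structure as FLoss, but the target computation uses torch ops
        # (here just a made-up criterion), so autograd tracks everything and
        # no custom backward is needed.
        def __init__(self, targets):
            super().__init__()
            self.register_buffer("targets", targets)

        def forward(self, x, y):
            res_values = (x - y).abs().sum(dim=-1)   # stand-in for solve_targets
            return torch.clamp_min(res_values - self.targets, 0.).sum()

With that, loss = FLossTorch(targets).to(device); loss(x, y).backward() propagates gradients end to end.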

Thanks!

Unfortunately, coding purely in torch is not possible, and computing the gradients also seems to be impossible.
The numpy function actually solves an optimization problem.

Well if there are no gradients, there is not much we can do :smiley:

For the optimization problem, you have two choices:

  • differentiate the whole process by rewriting it in pytorch
  • find some conditions verified by the solution and the input (like the KKT conditions) and use them to write a custom backward that computes the gradients (see the toy sketch below)
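
As a toy illustration of the second point (not your actual problem): if the inner optimization were, say, min_z ||z - x||^2 subject to z >= 0, the optimality conditions give the solution and its derivative in closed form, so backward can be written without differentiating through the solver itself:

    import torch

    class ProjectNonneg(torch.autograd.Function):
        # Toy inner problem: argmin_z ||z - x||^2  s.t.  z >= 0.
        # The solution is z* = clamp(x, min=0), and the optimality conditions
        # give dz*/dx = 1 where x > 0 and 0 elsewhere, so backward does not
        # need to differentiate an iterative solver.
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.clamp(min=0)   # in practice this could be a numpy/scipy solver call

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return grad_output * (x > 0).to(grad_output.dtype)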

Hmm yeah… thanks!

Another idea was to approximate the optimization with a second neural network, but that’s just an idea, and it would make things much more complicated.

Hi Jan-Hendrik!

It is not unreasonable to compute your gradients numerically
(although suboptimal if you can compute them analytically).

So, in your forward pass you might compute:

    loss_np = numpy_stuff(x, y)
    grad = np.zeros_like(x)                  # gradient estimate, same shape as x
    for i in range(x.size):                  # loop over the elements of x
        x_delt = x.copy()
        x_delt.flat[i] += delta              # perturb one coordinate
        grad.flat[i] = (numpy_stuff(x_delt, y) - loss_np) / delta
    ...

Then store (or accumulate) your grad_i to use in a custom
backward() function.

This is just a superficial sketch, and numerical differentiation is
more delicate than indicated above, but the main point is that
numerical differentiation can be a legitimate way to implement
backward().
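
To make this concrete, here is a rough sketch of how such a finite-difference
backward() could be packaged as a torch.autograd.Function. (numpy_stuff is just
a placeholder for your real solver, and I only differentiate with respect to x.)

    import numpy as np
    import torch

    DELTA = 1e-4   # finite-difference step size

    def numpy_stuff(x, y):
        # Placeholder for the real numpy solver; returns a scalar loss.
        return float(np.sum(np.maximum(x - y, 0.0) ** 2))

    class NumpyLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            x_np = x.detach().cpu().numpy().astype(np.float64)
            y_np = y.detach().cpu().numpy().astype(np.float64)
            loss_np = numpy_stuff(x_np, y_np)
            # Forward finite differences, one coordinate of x at a time.
            grad = np.zeros_like(x_np)
            for i in range(x_np.size):
                x_pert = x_np.copy()
                x_pert.flat[i] += DELTA
                grad.flat[i] = (numpy_stuff(x_pert, y_np) - loss_np) / DELTA
            ctx.save_for_backward(torch.as_tensor(grad, dtype=x.dtype, device=x.device))
            return x.new_tensor(loss_np)

        @staticmethod
        def backward(ctx, grad_output):
            (grad_x,) = ctx.saved_tensors
            return grad_output * grad_x, None   # no gradient for y

Then loss = NumpyLoss.apply(y_pred, y_true); loss.backward() propagates the
numerically estimated gradient back into your network.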

Note that some optimization problems can be solved more
quickly if you give them a good initial guess. So you might be
able to speed things up by reusing your initial solution when
re-solving your optimization multiple times for the numerical
differentiation.
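
(Purely as an assumption about your setup: if the solver were something like
scipy.optimize.minimize, reusing the previous solution as the starting point
might look like this.)

    import numpy as np
    from scipy.optimize import minimize

    def objective(z, x, y):
        # Made-up objective standing in for the real inner problem.
        return np.sum((z - x) ** 2) + 0.1 * np.sum((z - y) ** 2)

    _prev = {"z": None}   # cache of the last solution, used as a warm start

    def solve_targets(x, y):
        x0 = _prev["z"] if _prev["z"] is not None else np.zeros_like(x)
        res = minimize(objective, x0, args=(x, y), method="L-BFGS-B")
        _prev["z"] = res.x   # reuse next time, e.g. for the finite-difference re-solves
        return res.x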

(And, as Alban noted, some optimization algorithms naturally give
you – or can be relatively cheaply modified to give you – the
gradients you will need for your custom backward().)

But yes, just to reiterate what Alban said: You will have to
implement a custom backward() one way or another (or
implement your custom loss function entirely with differentiable
pytorch tensor operations).

Best.

K. Frank


Thank you K. Frank for the good suggestions!
I guess I can use finite differences for the gradient approximation at the cost of longer runtimes, but it’s perfect for quickly validating my idea!

The optimization problem can theoretically be solved in torch, but that would require rewriting LOTS of code for just a simple experiment.