How to perform chain rule gradient


I am trying to figure out where to start with implementing this and was hoping for some pointers. Sorry if I’m missing something obvious.

Instead of computing a standard L2 cost function of the form

|| N(s) - d ||_2^2

where N is the network, s is the network input, and d is the expected output, I want to apply a chained cost function, i.e.

|| F(N(s)) - d ||_2^2

where F is a function whose forward I cannot define directly, but whose gradient I can compute by other means: there is an (unfortunately expensive) way to compute F*(F(s) - d), where F* is the adjoint of F. So my training gradient should be F*(F(N(s)) - d) * grad(N(s)).


I am not sure I understand exactly what you want.
But if you want to create a new function with a custom forward (potentially inexact) and a custom backward (potentially the gradient of the true forward that you wanted but couldn’t compute), you can check here how to do it.
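To make that concrete, here is a minimal sketch of a `torch.autograd.Function` with a custom backward that calls an external adjoint routine, which matches the F*(F(N(s)) - d) * grad(N(s)) gradient described above. The operator F here is a toy linear map `A @ x` (so the script actually runs); `apply_F` and `apply_F_adjoint` are hypothetical stand-ins for your real (expensive) forward and adjoint computations, and the network/dimension choices are illustrative assumptions, not part of the original question.

```python
import torch

# Toy stand-in for the operator: F(x) = A @ x. In the real setting the
# forward may be inexact and the adjoint is computed "by other means";
# here A.T @ r plays that role. A and all sizes are illustrative.
A = torch.randn(3, 4)

def apply_F(x):
    # Hypothetical forward evaluation of F (may be approximate).
    return A @ x

def apply_F_adjoint(r):
    # Hypothetical external routine computing F*(r).
    return A.T @ r

class FOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Autograd does not trace what happens in here, so any
        # external/inexact computation is fine.
        return apply_F(x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dLoss/dF(x); returning F*(grad_output)
        # gives dLoss/dx, and autograd chains grad(N(s)) for you.
        return apply_F_adjoint(grad_output)

# Usage: loss = || F(N(s)) - d ||_2^2. Calling backward() then yields
# F*(F(N(s)) - d) * grad(N(s)) for the network parameters.
net = torch.nn.Linear(5, 4)
s = torch.randn(5)
d = torch.randn(3)

out = FOp.apply(net(s))
loss = torch.sum((out - d) ** 2)
loss.backward()
print(net.weight.grad.shape)
```

Since F here happens to be differentiable, you can sanity-check the custom backward by comparing against plain autograd on `A @ net(s)`; in your real use case that check is not available, which is exactly why the custom `backward` is needed.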