Computing partial derivatives of NN outputs using retain_graph=True and create_graph=True, then optimizing over both the NN outputs and their derivatives

I am trying to solve a PDE, using the initial values and knowledge of the PDE itself to construct the losses. The code is more or less a copy of a TensorFlow implementation, yet I get different results, so I am inclined to believe that something odd is going on when backpropagating through the same graph multiple times and then optimizing over the resulting derivatives as well.

In the neural net:

        x.grad = None
        u.backward(torch.ones((batch_size, 1)), retain_graph=True, create_graph=True)
        u_x = x.grad

        # retain_graph: allows backprop through the same graph again; all of the derivatives need this
        # create_graph: makes x.grad itself have a grad_fn, so it can be differentiated again

        #x.grad = None
        u_x.backward(torch.ones((batch_size, 1)), retain_graph=True, create_graph=True)
        u_xx = x.grad

        x.grad = None
        v.backward(torch.ones((batch_size, 1)), retain_graph=True,  create_graph=True)
        v_x = x.grad

        # x.grad = None
        v_x.backward(torch.ones((batch_size, 1)), retain_graph=True, create_graph=True)
        v_xx = x.grad

        t.grad = None
        u.backward(torch.ones((batch_size, 1)), retain_graph=True, create_graph=True)
        u_t = t.grad

        t.grad = None
        v.backward(torch.ones((batch_size, 1)), retain_graph=True, create_graph=True)
        v_t = t.grad

        f_u = u_t.float() + 0.5 * v_xx.float() + (u ** 2 + v ** 2) * v

        f_v = v_t.float() - 0.5 * u_xx.float() - (u ** 2 + v ** 2) * u
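For reference, the two residuals f_u and f_v in the code above match the imaginary part and the negated real part of the nonlinear Schrödinger equation for a complex field h = u + iv. Identifying the PDE this way is an editorial reading of the formulas, not something stated in the post:

```latex
h = u + i\,v, \qquad i\,h_t + \tfrac{1}{2}\,h_{xx} + |h|^2 h = 0
% imaginary part:
f_u = u_t + \tfrac{1}{2}\,v_{xx} + (u^2 + v^2)\,v = 0
% real part, multiplied by -1:
f_v = v_t - \tfrac{1}{2}\,u_{xx} - (u^2 + v^2)\,u = 0
```

Driving both residuals to zero therefore enforces the PDE at the collocation points, which is why u_xx and v_xx must be exact second derivatives.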

The problem continues here (the first part was posted by accident before it was finished):

My loss function is a function of many variables including the derivatives (MSELoss is used inside the function):

        loss = loss_function(u_0.double(), u_0_target, v_0.double(), v_0_target, u_lb, u_rb, v_lb, v_rb, u_x_lb, u_x_rb, v_x_lb, v_x_rb, f_u, f_v)

Then the optimization step:

I found that the model is able to fit the loss terms for u and v, but fails to fit f_u and f_v, which are functions of the first and second derivatives.
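One way to test the derivative recipe in isolation is to apply it to a function whose derivatives are known. The sketch below uses the toy function u = sin(x) (an assumption chosen for illustration, not the poster's network), runs the same backward sequence as the question with the second `x.grad = None` still commented out, and compares the result with `torch.autograd.grad`.

```python
import torch
from torch import autograd

torch.manual_seed(0)
# Toy input with known derivatives: u = sin(x), u_x = cos(x), u_xx = -sin(x).
x = torch.rand(7, 1, requires_grad=True)
ones = torch.ones_like(x)

# Recipe from the question: the second `x.grad = None` is commented out.
u = torch.sin(x)
x.grad = None
u.backward(ones, create_graph=True)
u_x_via_grad = x.grad
# x.grad = None
u_x_via_grad.backward(ones, create_graph=True)  # adds d(u_x)/dx onto x.grad
u_xx_via_grad = x.grad

# Same derivatives with autograd.grad, which returns them instead of accumulating.
u = torch.sin(x)
(u_x,) = autograd.grad(u, x, ones, create_graph=True)
(u_xx,) = autograd.grad(u_x, x, ones, create_graph=True)

print(torch.allclose(u_xx, -torch.sin(x)))                         # True
print(torch.allclose(u_xx_via_grad, torch.cos(x) - torch.sin(x)))  # True: u_x + u_xx
```

So whenever a reset is skipped, the "second derivative" read from `.grad` silently contains the first derivative as well; in the question's code this would affect both u_xx and v_xx.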

I hope someone can notice a mistake because I am running out of ideas.



A few notes to improve your code:

  • If you set create_graph=True, you don't need to pass the retain_graph argument: it defaults to the value of create_graph.
  • When working with higher-order derivatives, we recommend using autograd.grad instead of backward. That way you don't have to manipulate the .grad fields and risk making mistakes (in particular, several of the .grad resets in your code are commented out, which can lead to unexpected behavior when you compute v_xx, for example). You can do v_x = autograd.grad(v, x, torch.ones((batch_size, 1)), create_graph=True)[0].
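Following the autograd.grad suggestion, here is a minimal sketch of the whole residual computation. The model `TinyPINN`, its size, the tanh activations, the helper `d` and the random collocation points are hypothetical stand-ins, since the poster's `pinn` is not shown; the residual formulas are copied from the question. Unlike backward, autograd.grad returns the requested gradients and writes nothing into any `.grad` field, including those of the model parameters.

```python
import torch
import torch.nn as nn
from torch import autograd


class TinyPINN(nn.Module):
    """Hypothetical stand-in for the poster's network: maps (x, t) to (u, v)."""

    def __init__(self, hidden: int = 20):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x, t):
        out = self.net(torch.cat([x, t], dim=1))
        return out[:, 0:1], out[:, 1:2]


def d(y, inp):
    """Per-sample dy/d(inp) (valid because row i of y depends only on row i of inp).

    create_graph=True keeps the result differentiable for higher derivatives
    and for the loss.
    """
    return autograd.grad(y, inp, grad_outputs=torch.ones_like(y), create_graph=True)[0]


def residuals(model, x, t):
    u, v = model(x, t)
    u_x, v_x = d(u, x), d(v, x)
    u_xx, v_xx = d(u_x, x), d(v_x, x)
    u_t, v_t = d(u, t), d(v, t)
    sq = u ** 2 + v ** 2
    f_u = u_t + 0.5 * v_xx + sq * v
    f_v = v_t - 0.5 * u_xx - sq * u
    return f_u, f_v


torch.manual_seed(0)
model = TinyPINN()
x = torch.rand(8, 1, requires_grad=True)  # placeholder collocation points
t = torch.rand(8, 1, requires_grad=True)

f_u, f_v = residuals(model, x, t)
# autograd.grad has not written anything into the parameters' .grad fields.
param_grads_untouched = all(p.grad is None for p in model.parameters())

loss = (f_u ** 2).mean() + (f_v ** 2).mean()
loss.backward()  # only now are the parameter gradients populated
```

Tanh is used here because a ReLU network is piecewise linear, so its second derivatives (u_xx, v_xx) are zero almost everywhere; a smooth activation is the usual choice when the loss involves second derivatives.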

Does that solve your issue?

Thanks for the response and the notes. Although they really help with clarity (I was getting quite uncertain about my method), they still do not solve my problem. This leads me to believe that the error may lie elsewhere. Perhaps take a look at my loss function:

        import torch
        import torch.nn as nn

        def loss_function(u_0, u_0_target, v_0, v_0_target, u_lb, u_rb, v_lb, v_rb, u_x_lb, u_x_rb,
                          v_x_lb, v_x_rb, f_u, f_v):
            loss = nn.MSELoss()
            return loss(u_0, u_0_target) + loss(v_0, v_0_target) + loss(u_lb, u_rb) + \
                   loss(v_lb, v_rb) + loss(u_x_lb, u_x_rb) + loss(v_x_lb, v_x_rb) + \
                   loss(f_u, torch.zeros((f_u.shape[0], 1))) + loss(f_v, torch.zeros((f_v.shape[0], 1)))

The NN's task is to minimize each of these individual pairs (f_u and f_v should be 0, hence the torch.zeros). I have previously plotted the progress of each individual loss, and it was clear that every loss converged to zero except those of the two f terms, which stayed roughly constant. Do you have any idea how this is possible?
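Since the diagnosis relies on watching each term separately, a variant of the loss function that returns the terms by name makes that logging straightforward. The function name `loss_terms`, the dictionary keys and the random placeholder inputs are hypothetical; the terms themselves are those of loss_function above, except that `torch.zeros_like` is used for the residual targets so they automatically match f_u and f_v in shape, dtype and device.

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()


def loss_terms(u_0, u_0_target, v_0, v_0_target, u_lb, u_rb, v_lb, v_rb,
               u_x_lb, u_x_rb, v_x_lb, v_x_rb, f_u, f_v):
    """Hypothetical variant of loss_function that keeps every MSE term separate."""
    return {
        "u_0": mse(u_0, u_0_target),
        "v_0": mse(v_0, v_0_target),
        "u_lb_rb": mse(u_lb, u_rb),
        "v_lb_rb": mse(v_lb, v_rb),
        "u_x_lb_rb": mse(u_x_lb, u_x_rb),
        "v_x_lb_rb": mse(v_x_lb, v_x_rb),
        # zeros_like matches the shape, dtype and device of the residual.
        "f_u": mse(f_u, torch.zeros_like(f_u)),
        "f_v": mse(f_v, torch.zeros_like(f_v)),
    }


# Usage with random placeholder tensors (the shapes are assumptions).
torch.manual_seed(0)
args = [torch.rand(16, 1) for _ in range(14)]
terms = loss_terms(*args)
total = sum(terms.values())  # the scalar you would call .backward() on
history = {name: [] for name in terms}
for name, value in terms.items():
    history[name].append(value.item())  # one point per iteration for plotting
```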

Another thing worth noting is that, to obtain the individual inputs to the loss function, I had to run the NN four times:

        u_0, v_0, _, _, _, _ = pinn(batch_input_0)
        u_lb, v_lb, u_x_lb, v_x_lb, _, _ = pinn(batch_input_lb)
        u_rb, v_rb, u_x_rb, v_x_rb, _, _ = pinn(batch_input_rb)
        _, _, _, _, f_u, f_v = pinn(batch_input_f)

though I doubt PyTorch has any issue with this.

Your help is very much appreciated.

Running the Module multiple times should be perfectly fine.
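To illustrate why several forward calls are fine: for layers that treat each sample independently (no batch statistics, e.g. no BatchNorm in training mode), combining the outputs of separate calls in one loss gives the same parameter gradients as a single call on the concatenated batch. The small model and inputs below are hypothetical stand-ins for pinn and its batches, and a summed loss is used so both sides weight every sample equally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small model; float64 so the comparison is not blurred by rounding.
model = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2)).double()
a = torch.rand(5, 2, dtype=torch.float64)
b = torch.rand(3, 2, dtype=torch.float64)

# Two separate forward calls, combined in one loss.
model.zero_grad()
(model(a).pow(2).sum() + model(b).pow(2).sum()).backward()
grads_separate = [p.grad.clone() for p in model.parameters()]

# One forward call on the concatenated batch.
model.zero_grad()
model(torch.cat([a, b])).pow(2).sum().backward()
grads_joint = [p.grad.clone() for p in model.parameters()]

same = all(torch.allclose(g1, g2) for g1, g2 in zip(grads_separate, grads_joint))
print(same)  # True
```

With MSELoss (a mean per call) the per-call terms are averaged over batches of different sizes, so the gradients would differ by those weights, but autograd still routes every contribution back to the shared parameters.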

At this point, I would recommend using torchviz.
You can use it to plot the graph associated with one or more Tensors.
In particular, you can check that everything is connected to the right places, as you expect.
Note that, for readability, if your pinn model is large you can replace it with a simpler version with fewer layers.
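As a text-only complement to the torchviz plot (a sketch, not part of torchviz), autograd.grad with allow_unused=True can report which parameters a tensor such as f_u is not connected to at all. The helper name `unused_parameters` and the toy model are hypothetical.

```python
import torch
import torch.nn as nn
from torch import autograd


def unused_parameters(output: torch.Tensor, model: nn.Module) -> list:
    """Hypothetical helper: names of trainable parameters `output` is not connected to."""
    named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    if not output.requires_grad:  # e.g. after .detach(): nothing is connected
        return [n for n, _ in named]
    grads = autograd.grad(output.sum(), [p for _, p in named],
                          retain_graph=True, allow_unused=True)
    return [n for (n, _), g in zip(named, grads) if g is None]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 2))
inp = torch.rand(4, 2)

print(unused_parameters(model(inp), model))                  # []
print(unused_parameters(model[2](torch.rand(4, 8)), model))  # ['0.weight', '0.bias']
print(unused_parameters(model(inp).detach(), model))         # all four names
```

A parameter that is connected but happens to receive a zero gradient comes back as a zero tensor rather than None, so this only catches missing connections, not vanishing ones.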

Great tip, I’ll try it out.

Thanks for your help!