torch.autograd.grad() and backward() calls to compute the gradients

I’m working with Physics-Informed Neural Networks and enforce the physics by computing derivatives with torch.autograd.grad(). I only calculate first-order derivatives in my framework. After that I compute the total loss and call loss.backward() on it, which computes the gradients of the loss with respect to all the tensors that have requires_grad set to True (the weights and biases, and also the network inputs from which the network outputs are computed). Am I right?
My question is: what values should the retain_graph and create_graph options of torch.autograd.grad() take, and do these values affect the backward() call? If so, how? Can anyone shed some light on this?

Hi @Aakhash_Sundaresan,

Do you have a minimal reproducible example for your problem?

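If you want to compute higher-order gradients, you need create_graph=True (retain_graph defaults to the value of create_graph, so passing both as True is equivalent). By default, PyTorch frees the computational graph, i.e. the intermediate values saved for the backward pass, as soon as a gradient call finishes, and without create_graph the returned gradient is not itself part of a graph, so it cannot be differentiated again or backpropagated through.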
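A minimal sketch of what create_graph changes, using a toy function y = x**3 instead of a network (the tensor values are arbitrary). The first call records a graph so its result can be differentiated again; the last call shows that without create_graph the result carries no graph:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 3

# First derivative: create_graph=True records the gradient computation itself,
# so dy_dx is differentiable (retain_graph defaults to True in this case)
dy_dx = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]

# Second derivative: differentiate the first derivative again
d2y_dx2 = torch.autograd.grad(dy_dx, x, torch.ones_like(dy_dx))[0]

print(dy_dx)    # equals 3 * x**2
print(d2y_dx2)  # equals 6 * x

# Without create_graph, the returned gradient cannot be differentiated again
y2 = x ** 3
dy_dx_plain = torch.autograd.grad(y2, x, torch.ones_like(y2))[0]
print(dy_dx_plain.requires_grad)  # False
```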

    # Excerpt from the trainer class; assumes `import torch`,
    # `from itertools import zip_longest` and a module-level `device`.
    def PDE_loss(self, pred, target):
        # target: collocation points (x, y), created with requires_grad=True
        # pred:   network output at those points, columns (u, v, p)
        xy = target
        u, v, p = pred[:, 0:1], pred[:, 1:2], pred[:, 2:3]
        ones = torch.ones_like(u)  # grad_outputs for the (N, 1) fields

        # first-order derivatives: one grad call per field returns both d/dx and d/dy.
        # create_graph=True lets them be differentiated again (u_xx, ...) and lets
        # loss.backward() reach the network parameters through them.
        du = torch.autograd.grad(u, xy, ones, retain_graph=True, create_graph=True)[0]
        dv = torch.autograd.grad(v, xy, ones, retain_graph=True, create_graph=True)[0]
        dp = torch.autograd.grad(p, xy, ones, retain_graph=True, create_graph=True)[0]
        u_x, u_y = du[:, 0:1], du[:, 1:2]
        v_x, v_y = dv[:, 0:1], dv[:, 1:2]
        p_x, p_y = dp[:, 0:1], dp[:, 1:2]

        # second-order derivatives: still create_graph=True, because they appear
        # in the loss and backward() must differentiate through them
        u_xx = torch.autograd.grad(u_x, xy, ones, retain_graph=True, create_graph=True)[0][:, 0:1]
        u_yy = torch.autograd.grad(u_y, xy, ones, retain_graph=True, create_graph=True)[0][:, 1:2]
        v_xx = torch.autograd.grad(v_x, xy, ones, retain_graph=True, create_graph=True)[0][:, 0:1]
        v_yy = torch.autograd.grad(v_y, xy, ones, retain_graph=True, create_graph=True)[0][:, 1:2]

        # continuity equation
        f0 = u_x + v_y
        # momentum equations
        f1 = u*u_x + v*u_y + p_x - 1/self.Re*(u_xx + u_yy)
        f2 = u*v_x + v*v_y + p_y - 1/self.Re*(v_xx + v_yy)

        mse_f0 = self.MSE_loss(f0, torch.zeros_like(f0))
        mse_f1 = self.MSE_loss(f1, torch.zeros_like(f1))
        mse_f2 = self.MSE_loss(f2, torch.zeros_like(f2))

        return mse_f0 + mse_f1 + mse_f2


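    def closure(self):
        total_batch_loss = 0.0
        col_batch_loss = 0.0
        bc_batch_loss = 0.0
        outlet_batch_loss = 0.0

        self.model.train()

        # LBFGS calls the closure itself and needs the gradient of the whole
        # objective, so gradients are accumulated over all batches and zeroed once
        is_lbfgs = isinstance(self.optimizer, torch.optim.LBFGS)
        if is_lbfgs:
            self.optimizer.zero_grad()

        col_loader, bc_loader, out_loader = self.data[0], self.data[1], self.data[2]
        # zip_longest pads the shorter loaders with None
        for col_data, bc_data, outlet_data in zip_longest(col_loader, bc_loader, out_loader):
            if not is_lbfgs:
                self.optimizer.zero_grad()
            # a missing batch contributes zero instead of a stale or undefined loss
            zero = torch.zeros((), device=device)
            col_loss, bc_loss, outlet_loss = zero, zero, zero

            # unpack data (None when that loader is exhausted)
            col_xy = col_data
            bc_xy, bc_uv = (bc_data[0], bc_data[1]) if bc_data is not None else (None, None)
            out_xy, out_p = (outlet_data[0], outlet_data[1]) if outlet_data is not None else (None, None)

            # forward pass
            if col_xy is not None:
                # fresh leaf copy so autograd.grad can differentiate w.r.t. the inputs
                col_xy = col_xy.clone().requires_grad_(True)
                col_pred = self.model(col_xy)
                col_loss = self.PDE_loss(col_pred, col_xy)
            if bc_xy is not None:
                bc_pred = self.model(bc_xy)
                bc_loss = self.BC_loss(bc_pred, bc_uv)
            if out_xy is not None:
                out_pred = self.model(out_xy)
                outlet_loss = self.Outlet_loss(out_pred, out_p)
            total_loss = col_loss + bc_loss + outlet_loss

            # each batch builds a fresh graph, so retain_graph is not needed here
            total_loss.backward()
            if not is_lbfgs:
                self.optimizer.step()

            total_batch_loss += total_loss.detach().cpu()
            col_batch_loss += col_loss.detach().cpu()
            bc_batch_loss += bc_loss.detach().cpu()
            outlet_batch_loss += outlet_loss.detach().cpu()

        self.losses["total"].append(total_batch_loss)
        self.losses["pde"].append(col_batch_loss)
        self.losses["bc"].append(bc_batch_loss)
        self.losses["outlet"].append(outlet_batch_loss)
        # LBFGS.step(closure) requires the closure to return the loss
        return total_batch_loss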

In the code above, I don’t see the rationale for using retain_graph and create_graph when calculating the second-order gradients. I understand that these options need to be switched on when computing the first-order gradients, so that those can be differentiated again to obtain the second-order gradients.

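You’re right that the first-order calls need create_graph=True (retain_graph defaults to the value of create_graph, so it is then True as well). In a PINN, though, the second-order calls need create_graph=True too, even though you’re not computing any third-order derivatives: u_xx, u_yy, v_xx and v_yy appear in the loss, so loss.backward() has to differentiate through them with respect to the network parameters. With create_graph=False those tensors are returned without a graph, so the viscous terms silently drop out of the parameter gradients; and since retain_graph would then default to False, the graph would also be freed before loss.backward() runs, which typically triggers a “Trying to backward through the graph a second time” error.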

The docs of torch.autograd.grad can be found here: torch.autograd.grad — PyTorch 2.3 documentation

torch.autograd.grad returns the gradient of the outputs with respect to the inputs as new tensors and does not touch any .grad attribute. Calling .backward(), on the other hand, accumulates (adds) gradients into the .grad attribute of every leaf tensor with requires_grad=True, such as the model parameters. This is why, with torch.optim.Optimizer subclasses such as torch.optim.Adam, you call optimizer.zero_grad() before loss.backward(): if you don’t, the .grad attributes hold the sum of the gradients from the current iteration and the previous ones.
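A small sketch of the difference, using a single scalar tensor w as a stand-in for a model parameter (the value is arbitrary):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# torch.autograd.grad returns the gradient and leaves w.grad untouched
(g,) = torch.autograd.grad(w ** 2, w)
print(g, w.grad)  # tensor(4.) None

# backward() accumulates into w.grad on every call
(w ** 2).backward()
(w ** 2).backward()
print(w.grad)  # tensor(8.), i.e. 4 + 4, because the old gradient was never cleared

# Clearing between steps is what optimizer.zero_grad() does for every parameter
w.grad = None
(w ** 2).backward()
print(w.grad)  # tensor(4.)
```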

Thanks a ton for your timely support. I also have another doubt: why would we need retain_graph in the loss.backward(retain_graph=True) call? What is the rationale behind it?

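I don’t think you need retain_graph=True in loss.backward() here. It is only required when you backpropagate through the same graph (or part of it) more than once, for example calling backward() twice on the same loss, or reusing a tensor from a previous iteration or mini-batch without recomputing it. If the interpreter throws an error asking you to use loss.backward(retain_graph=True), it’s usually better to find the tensor that is being reused than to add the flag, since retaining the graph also keeps its memory alive.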
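A minimal sketch of when retain_graph is actually required, on a toy scalar graph; torch.exp is used because it saves its output for the backward pass, and that saved tensor is what gets freed:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# retain_graph=True keeps the saved tensors, so a second backward() works
y = torch.exp(x)
y.backward(retain_graph=True)
y.backward()

# Without it, the saved tensors are freed after the first backward()
z = torch.exp(x)
z.backward()
second_pass_failed = False
try:
    z.backward()  # second pass through a freed graph
except RuntimeError as err:
    second_pass_failed = True
    print("RuntimeError:", err)
```

In the closure above, every mini-batch runs a fresh forward pass, so each total_loss has its own graph and backward() is called on it only once.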

Suppose I compute the gradient of the network output with respect to the input using torch.autograd.grad() to build the PDE residual loss. When I then call loss.backward(), do the gradients get accumulated in the tensors returned by the torch.autograd.grad() call? If so, do I need the create_graph=True option in the torch.autograd.grad() call?

If I understand you correctly, you’re asking whether the .grad attribute is populated on the PDE residual terms (which are calculated via the torch.autograd.grad calls), i.e. the u_xx, u_yy, etc. terms?

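That shouldn’t be an issue. backward() only accumulates into .grad on leaf tensors with requires_grad=True, which here are the parameters and col_xy (so the inputs do get a .grad, as you suspected). u_x, u_xx and the other results of torch.autograd.grad are non-leaf tensors, so their .grad stays None (PyTorch warns if you access it) unless you call .retain_grad() on them before backward(). The gradients you actually care about are the ones with respect to the parameters.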
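A minimal sketch of that check, with a small tanh network and random points standing in for the PINN and col_xy (both are assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))
xy = torch.rand(4, 2, requires_grad=True)  # leaf input, like col_xy

u = net(xy)
u_xy = torch.autograd.grad(u, xy, torch.ones_like(u), create_graph=True)[0]
u_xy.retain_grad()  # opt in to .grad on this non-leaf tensor
loss = (u_xy ** 2).mean()
loss.backward()

print(net[0].weight.grad is not None)  # True: parameters are leaf tensors
print(xy.grad is not None)             # True: xy is a leaf with requires_grad=True
print(u_xy.grad is not None)           # True only because of retain_grad()
print(u.grad)                          # None: non-leaf without retain_grad (PyTorch warns)
```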

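You are computing higher-order derivatives and your loss depends on them, so you do need to keep create_graph=True: it is what allows loss.backward() to propagate through the torch.autograd.grad results to the parameters. If you’re unsure, you can run your code with create_graph=True and create_graph=False on the second-order calls and compare the resulting parameter gradients.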
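A minimal sketch of that comparison, on a toy 1D network with an arbitrary residual u_x + u_xx rather than the Navier-Stokes setup (the network size, seed and residual are all assumptions for illustration). The second-order call passes retain_graph=True explicitly so that only create_graph changes between the two runs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy 1D network standing in for the PINN (size and activation are arbitrary)
net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
x = torch.linspace(0.0, 1.0, 10).unsqueeze(1).requires_grad_(True)


def param_grads(second_order_create_graph):
    """Return (u_xx.requires_grad, parameter gradients) of mean((u_x + u_xx)**2)."""
    net.zero_grad()
    u = net(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    # retain_graph=True keeps the graph alive for backward() in both runs
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), retain_graph=True,
                               create_graph=second_order_create_graph)[0]
    loss = ((u_x + u_xx) ** 2).mean()
    loss.backward()
    # the output bias never enters u_x or u_xx, so its .grad may stay None
    return u_xx.requires_grad, [
        p.grad.clone() if p.grad is not None else torch.zeros_like(p)
        for p in net.parameters()
    ]


req_on, grads_on = param_grads(True)
req_off, grads_off = param_grads(False)

print(req_on, req_off)  # True False
print(any(not torch.allclose(a, b) for a, b in zip(grads_on, grads_off)))
```

With create_graph=False, u_xx comes back as a plain tensor: the loss still backpropagates through u_x, but the u_xx term no longer contributes to the parameter gradients, which is why the two sets of gradients differ.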