Problem computing derivatives using autograd

I’m trying to compute the derivative of a tensor with respect to another tensor inside the function shown below:

from torch.autograd import grad  # grad is torch.autograd.grad, used for the derivatives below

def calculate_pde_momentum_loss(self,
                                    pred_u_XYf:torch.Tensor, 
                                    pred_T_XYf:torch.Tensor, 
                                    XYf:torch.Tensor 
                                    ) -> torch.Tensor:
        """
        Summary:
        -------
        Sub-routine to compute the residual loss of the governing equations (the Continuity, X-Momentum and Y-Momentum PDEs) at the collocation points in the computational domain.
         
        PyTorch autograd version of calculating the residuals of the governing differential equations (momentum equations) of the Reynolds-Averaged Navier-Stokes (RANS) model.
        --> Net 1 - Output comprises the variables for the Continuity, X-Momentum and Y-Momentum equations and their fluxes:
        1. U - u-velocity
        2. V - v-velocity
        3. P - Pressure
        4. u'u' - x-component Reynolds stress
        5. u'v' - xy-component Reynolds stress
        6. v'v' - y-component Reynolds stress
        7. sigma11 - ShearXX
        8. sigma12 - ShearXY
        9. sigma22 - ShearYY

        --> Net 2 - Output comprises the variables for:
        1. T - Temperature (Non-dimensionalized)
        2. u'T' - Turbulent Heat flux in X-direction
        3. v'T' - Turbulent Heat flux in Y-direction
        4. qx - Heat Flux in X
        5. qy - Heat Flux in Y

        Args:
        ----
            - `pred_u_XYf` (`torch.Tensor`): The output (predictions) from the Momentum network, given the collocation points in the computational domain as input to the network.
            - `pred_T_XYf` (`torch.Tensor`): The output (predictions) from the Energy network, given the collocation points in the computational domain as input to the network.
            - `XYf` (`torch.Tensor`): The collocation points in the computational domain used to compute the PDE residuals.

        Returns:
        -------
            `torch.Tensor`: Returns the scalar overall loss value of the PDE residuals in the entire computational domain.
        """
        # Unpack the predicted individual flow variables corresponding to the dependent quantities in the momentum equations, from the neural network output
        u = pred_u_XYf[:,0:1]
        v = pred_u_XYf[:,1:2]
        p = pred_u_XYf[:,2:3]
        uu = pred_u_XYf[:,3:4]
        uv = pred_u_XYf[:,4:5]
        vv = pred_u_XYf[:,5:6]
        s11 = pred_u_XYf[:,6:7]
        s12 = pred_u_XYf[:,7:8]
        s22 = pred_u_XYf[:,8:9]
        
        # Unpack the predicted temperature variable from the neural network output
        T = pred_T_XYf[:,0:1]
        # uT = pred_T_XYf[:,1:2]
        # vT = pred_T_XYf[:,2:3]
        # qx = pred_T_XYf[:,3:4]
        # qy = pred_T_XYf[:,4:5]
        
        # Unpack the x-y coordinates of the collocation points in the interior domain
        Xf = XYf[:,0:1]
        Yf = XYf[:,1:2]
        
        # Calculate all the derivatives using the torch.autograd.grad functionality
        dU_dX = grad(u,Xf,grad_outputs=torch.ones_like(u))[0]    # <------------------------- error occurs here
        dU_dY = grad(u,Yf,grad_outputs=torch.ones_like(u))[0]
        dV_dX = grad(v,Xf,grad_outputs=torch.ones_like(v))[0] 
        dV_dY = grad(v,Yf,grad_outputs=torch.ones_like(v))[0] 
        
        duu_dX = grad(uu,Xf,grad_outputs=torch.ones_like(uu))[0]
        dvv_dY = grad(vv,Yf,grad_outputs=torch.ones_like(vv))[0]
        
        duv_dX = grad(uv,Xf,grad_outputs=torch.ones_like(uv))[0]
        duv_dY = grad(uv,Yf,grad_outputs=torch.ones_like(uv))[0]
        
        dS11_dX = grad(s11,Xf,grad_outputs=torch.ones_like(s11))[0]
        dS12_dX = grad(s12,Xf,grad_outputs=torch.ones_like(s12))[0] 
        dS12_dY = grad(s12,Yf,grad_outputs=torch.ones_like(s12))[0] 
        dS22_dY = grad(s22,Yf,grad_outputs=torch.ones_like(s22))[0] 
        
        dUU_dX = grad((u*u),Xf,grad_outputs=torch.ones_like((u*u)))[0]  
        dUV_dX = grad((u*v),Xf,grad_outputs=torch.ones_like((u*v)))[0]  
        dUV_dY = grad((u*v),Yf,grad_outputs=torch.ones_like((u*v)))[0]  
        dVV_dY = grad((v*v),Yf,grad_outputs=torch.ones_like((v*v)))[0]  
        
        # Assemble the RANS equations from the partial derivatives constructed above with the autograd engine
        #-----Continuity Equation------#
        epsilon1 = dU_dX + dV_dY # Residual from continuity equation
        
        #-----X - Momentum Equation----#
        epsilon2 = dUU_dX + dUV_dY + duu_dX + duv_dY - dS11_dX - dS12_dY # Residual from X-Momentum equation
        
        #-----Y - Momentum Equation------#
        # epsilon3 = dUV_dX + dVV_dY + duv_dX + dvv_dY - dS12_dX - dS22_dY # Residual from Y-Momentum equation (For the case of non-buoyancy driven flow)

        epsilon3 = dUV_dX + dVV_dY + duv_dX + dvv_dY - dS12_dX - dS22_dY + self.Ri*T # Residual from Y-Momentum equation (For the case of buoyancy driven flow)
        
        #-----Shear stress equations-----#
        epsilon4 = s11 + p - (2.0/self.Re)*(dU_dX)
        epsilon5 = s22 + p - (2.0/self.Re)*(dV_dY)
        epsilon6 = s12 - (1.0/self.Re)*(dU_dY + dV_dX)
        epsilon7 = p + 0.5*(s11 + s22)
        
        # Calculate the losses of the momentum equations
        loss_PDE_Continuity = self.loss_criterion(epsilon1,torch.zeros_like(epsilon1)) # MSE loss of prediction vs target values
        loss_PDE_X_Momentum = self.loss_criterion(epsilon2,torch.zeros_like(epsilon2))
        loss_PDE_Y_Momentum = self.loss_criterion(epsilon3,torch.zeros_like(epsilon3))
        loss_PDE_ShearXX = self.loss_criterion(epsilon4,torch.zeros_like(epsilon4))
        loss_PDE_ShearYY = self.loss_criterion(epsilon5,torch.zeros_like(epsilon5))
        loss_PDE_ShearXY = self.loss_criterion(epsilon6,torch.zeros_like(epsilon6))
        loss_PDE_Pressure = self.loss_criterion(epsilon7,torch.zeros_like(epsilon7))
        
        # Return the Total momentum loss
        return (loss_PDE_Continuity + loss_PDE_X_Momentum + loss_PDE_Y_Momentum + loss_PDE_ShearXX + loss_PDE_ShearYY + loss_PDE_ShearXY + loss_PDE_Pressure)

In the training loop, I use a DataLoader to get a batch of data, send it to the GPU, and activate the requires_grad property of the tensor with respect to which I want to compute the derivatives, in order to set up the physics equations. The training loop is as follows:

def train_pinn_rans_energy_model(self,useLBFGS:bool = False) -> None:
        """
        Summary:
        -------
        Sub-routine to train the PINN_RANS_Energy neural network model.

        Args:
        ----
            - `useLBFGS` (bool, optional): Flag to either train the network with the Adam optimizer for all epochs, or switch to the LBFGS optimizer after a specified number of Adam iterations. Defaults to False.
        """
        self.dnn1.train(True) # Setting the network to training mode
        self.dnn2.train(True)
        
        # Start the training loop with the Adam optimizer for a specified number of epochs, then optionally switch to the LBFGS optimizer
        max_consecutive_no_improvement = 2000
        consecutive_no_improvement = 0
        best_loss_mom = float('inf')  # Initialize with a high value for Momentum loss
        best_loss_en = float('inf') # Initialize with a high value for Energy loss

        batch_idx = int(0)

        # Training loop begins here (the outer loop iterates over epochs, the inner loop over batches of data)
        for epoch in range(self.num_epochs_adam):
            
            # Append the epoch into the list
            self.epoch.append(epoch) 

            # Loop over batches of data, compute the forward pass through the network, then calculate the total loss from the network predictions.
            # batch_idx counts the iterations over the mini-batches of data
            for (col_batch_data,bc_batch_data,DNS_batch_data) in zip(self.collocation_points_dataloader,
                                                                     self.boundary_conditions_dataloader,
                                                                     self.DNS_data_dataloader):
                # Store the batch_idx in the iterator variable
                self.iter.append(batch_idx)

                # Zero the gradients of the Adam optimizers
                self.optimizer_adam1.zero_grad() 
                self.optimizer_adam2.zero_grad()
                
                # Unpack the tensors into their respective variables, send them to the appropriate device, and turn on the requires_grad property where needed.
                # Fetching new data batch
                # Batched Collocation points
                XYf_Train = col_batch_data[0].to(self.device).requires_grad_() # Send the batched collocation points to the device and activate the requires_grad property to track the gradients.
                XYb_Train = bc_batch_data[0].to(self.device).requires_grad_() # Send the batched boundary collocation points to the device and activate the requires_grad property to track the gradients.
                XYd_Train = DNS_batch_data[0].to(self.device)

                # Get target Boundary conditions and DNS data batch
                Ub_Train = bc_batch_data[1].to(self.device) # Send the target boundary conditions tensor to the device
                Ud_Train = DNS_batch_data[1].to(self.device) # Send the target DNS data tensor to the device

                # Slice the tensor to get the (Momentum boundary conditions, Momentum DNS data) and (Temperature Boundary conditions, Temperature DNS data)
                Ub = Ub_Train[:,0:5]
                Tb = Ub_Train[:,5:8]

                Ud = Ud_Train[:,0:5]
                Td = Ud_Train[:,5:8]

                # Forward pass the network inputs to obtain the predictions from the momentum network
                pred_u_XYf = self.dnn1(XYf_Train) # Prediction from the Momentum network for Domain collocation points
                pred_T_XYf_ = self.dnn2(XYf_Train) # Prediction from the updated Energy network for Domain collocation points
                pred_u_XYb = self.dnn1(XYb_Train) # Prediction from the Momentum network for Boundary collocation points
                pred_u_XYd = self.dnn1(XYd_Train) # Prediction from the Momentum network for the DNS data coordinate points

                # Calculate the losses for the Momentum and Energy networks
                l_momentum,l_pde_momentum,l_bc_momentum,l_data_momentum = self.calculate_momentum_losses(pred_u_XYf=pred_u_XYf,
                                                                                                         pred_T_XYf=pred_T_XYf_,
                                                                                                         pred_u_XYb=pred_u_XYb,
                                                                                                         pred_u_XYd=pred_u_XYd,
                                                                                                         XYf=XYf_Train,
                                                                                                         XYb=None,
                                                                                                         Ub=Ub,
                                                                                                         Ud=Ud)
                
                # Backpropagate the total loss to compute the gradients with respect to the network parameters and update the network parameters for the Momentum network.
                l_momentum.backward()
                self.optimizer_adam1.step()

                # Forward pass the network inputs to obtain the predictions from the energy network
                pred_u_XYf_ = self.dnn1(XYf_Train) # Prediction from the updated Momentum network for Domain collocation points
                pred_T_XYf = self.dnn2(XYf_Train) # Prediction from the Energy network for Domain collocation points
                pred_T_XYb = self.dnn2(XYb_Train) # Prediction from the Energy network for Boundary collocation points
                pred_T_XYd = self.dnn2(XYd_Train) # Prediction from the Energy network for the DNS data coordinate points

                l_energy,l_pde_energy,l_bc_energy,l_data_energy = self.calculate_energy_losses(pred_u_XYf=pred_u_XYf_,
                                                                                               pred_T_XYf=pred_T_XYf,
                                                                                               pred_T_XYb=pred_T_XYb,
                                                                                               pred_T_XYd=pred_T_XYd,
                                                                                               XYf=XYf_Train,
                                                                                               XYb=XYb_Train,
                                                                                               Tb=Tb,
                                                                                               Td=Td)

                # Backpropagate the total loss to compute the gradients with respect to the network parameters and update the network parameters for the Energy network.
                l_energy.backward()
                self.optimizer_adam2.step()

                # Append the loss values to the list to fetch the training data later
                with torch.no_grad():
                    self.accumulate_momentum_loss(l_pde_momentum,l_bc_momentum,l_data_momentum,l_momentum)
                    self.accumulate_energy_loss(l_pde_energy,l_bc_energy,l_data_energy,l_energy) 

                if self.verbose:
                    if (self.iter[batch_idx]%10) == 0:
                        print(f"Epoch: {self.epoch[-1]:6d} ||"
                              f" Iteration: {self.iter[-1]:6d} ||"
                              f" Loss_M: {self.loss_total_momentum[-1]:0.5e} ||"
                              f" LossF_M: {self.loss_pde_momentum[-1]:0.5e} ||"
                              f" LossB_M: {self.loss_bc_momentum[-1]:0.5e} ||"
                              f" LossD_M: {self.loss_data_momentum[-1]:0.5e} ||"
                              f" Loss_E: {self.loss_total_energy[-1]:0.5e} ||"
                              f" LossF_E: {self.loss_pde_energy[-1]:0.5e} ||"
                              f" LossB_E: {self.loss_bc_energy[-1]:0.5e} ||"
                              f" LossD_E: {self.loss_data_energy[-1]:0.5e}")
                        
                # Early-stopping check: if the total loss has not improved for max_consecutive_no_improvement iterations, exit the training loop and inform the user that the loss has converged.
                # Check if the loss has improved
                if (self.loss_total_momentum[-1] < best_loss_mom) or (self.loss_total_energy[-1] < best_loss_en):
                    best_loss_mom = min(best_loss_mom, self.loss_total_momentum[-1])
                    best_loss_en = min(best_loss_en, self.loss_total_energy[-1])
                    consecutive_no_improvement = 0
                else:
                    consecutive_no_improvement += 1

                # Check if we should stop training
                if consecutive_no_improvement >= max_consecutive_no_improvement:
                    print(f"\nTraining stopped after {self.iter[-1]} iterations due to loss convergence.")
                    break
                
                # Check if the value is NaN and if so, break the loop
                if torch.isnan(l_momentum.detach()) or torch.isnan(l_energy.detach()):
                    print('\nSolution is expected to diverge...Choose a different set of hyperparameters for the network.')
                    break

                # Iterate the batch index variable
                batch_idx += 1

            # If the inner loop broke due to convergence, break the outer loop as well by checking the consecutive_no_improvement variable
            if consecutive_no_improvement >= max_consecutive_no_improvement:
                break

            # If the loss is NaN, break the outer epoch loop as well
            if torch.isnan(l_momentum.detach()) or torch.isnan(l_energy.detach()):
                break

For some peculiar reason, I’m not able to understand why there is a problem calculating the gradient of u with respect to Xf (marked with a long arrow in the code above). Can somebody please help me with this as soon as possible?

The error that I get is:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Please somebody help me out… :frowning:

This often means that you have performed an operation that autograd cannot track. Usually this is creating a new tensor, doing non-PyTorch ops, etc.
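That is exactly what is happening here: `Xf = XYf[:,0:1]` and `Yf = XYf[:,1:2]` are new tensors created by the slicing operation, and they were never fed into the network. The graph runs from `XYf` to `u`, so `grad(u, Xf, ...)` finds no path between them and raises the "not used in the graph" error. Below is a minimal sketch of two common workarounds; `net` here is a hypothetical stand-in for your `dnn1`, and the shapes are made up for illustration:

import torch
from torch.autograd import grad

# Hypothetical stand-in for self.dnn1: any (N, 2) -> (N, 9) network works here.
net = torch.nn.Sequential(torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 9))

XYf = torch.rand(128, 2, requires_grad=True)  # batched collocation points
u = net(XYf)[:, 0:1]

# Option 1: differentiate with respect to XYf (the tensor the network actually
# consumed) and slice the gradient: column 0 is du/dX, column 1 is du/dY.
dU_dXY = grad(u, XYf, grad_outputs=torch.ones_like(u), create_graph=True)[0]
dU_dX, dU_dY = dU_dXY[:, 0:1], dU_dXY[:, 1:2]

# Option 2: slice the coordinates *before* the forward pass and rebuild the
# network input with torch.cat, so Xf and Yf really are part of the graph.
Xf = XYf.detach()[:, 0:1].requires_grad_()
Yf = XYf.detach()[:, 1:2].requires_grad_()
u2 = net(torch.cat([Xf, Yf], dim=1))[:, 0:1]
dU_dX2 = grad(u2, Xf, grad_outputs=torch.ones_like(u2), create_graph=True)[0]

Also note the `create_graph=True`: your grad calls omit it, but since these derivatives feed into a loss that you later call `.backward()` on, the derivative computation itself must stay in the graph; otherwise the PDE residual loss will not backpropagate to the network parameters.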