Efficient Calculation of Derivatives for PINN Solvers in PyTorch

Hi, I am currently trying to implement Physics-Informed Neural Networks (PINNs). PINNs require computing derivatives of the model output with respect to its inputs. These derivatives are then used to calculate the residuals of PDEs such as the heat, Burgers', or Navier-Stokes equations, so one needs higher-order partial derivatives. I tried to use torch.autograd.grad to compute them. Here is the function I have implemented:

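import torch

def diff(y, xs):
    # Differentiate y successively with respect to each tensor in xs,
    # e.g. diff(y, [x, x]) is the second derivative of y with respect to x
    grad = y
    for x in xs:
        # create_graph=True keeps the graph so the result can be differentiated again
        grad = torch.autograd.grad(grad, x, grad_outputs=torch.ones_like(grad), create_graph=True)[0]
    return grad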

diff(y, xs) differentiates y successively with respect to each element of xs, so diff(y, [x, x]) is the second derivative of y with respect to x. This makes denoting and computing partial derivatives much easier:

y_pred = model([t, x, y])
diff(y_pred, [x, x]) # dydxx
diff(y_pred, [x, y, t]) # dydxyt
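As a sanity check, diff can be applied to a function whose derivatives are known in closed form instead of a model output. A minimal sketch: the test function f = x³·y·t and the grid are arbitrary choices, and grad_outputs is built with torch.ones_like(grad) so its shape always matches the tensor being differentiated.

```python
import torch

def diff(y, xs):
    # Differentiate y successively with respect to each tensor in xs
    grad = y
    for x in xs:
        grad = torch.autograd.grad(
            grad, x, grad_outputs=torch.ones_like(grad), create_graph=True
        )[0]
    return grad

N = 5
x = torch.linspace(0.0, 1.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))
y = torch.linspace(1.0, 2.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))
t = torch.linspace(2.0, 3.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))

f = x**3 * y * t                  # stand-in for a model output
d2f_dx2 = diff(f, [x, x])         # analytically 6 x y t
d3f_dxdydt = diff(f, [x, y, t])   # analytically 3 x^2

print(torch.allclose(d2f_dx2, 6 * x * y * t))   # True
print(torch.allclose(d3f_dxdydt, 3 * x**2))     # True
```

float64 is used only so the comparison is not limited by single-precision round-off.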

This function lets me write PDEs in a much more compact way. As an example, here is the 1D heat equation:

u = model([x, t])
def heat(u, x, t):
    alpha = 1.0
    return diff(u, [x, x]) - alpha * diff(u, [t])
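The heat residual can be sanity-checked on a closed-form solution instead of a network output. For alpha = 1, u(x, t) = sin(πx)·exp(-π²t) satisfies u_t = u_xx, so heat(u, x, t) should vanish up to round-off. A minimal sketch with its own copy of diff; the grid size and dtype are arbitrary choices:

```python
import math
import torch

def diff(y, xs):
    # Successive derivatives of y with respect to each tensor in xs
    grad = y
    for x in xs:
        grad = torch.autograd.grad(
            grad, x, grad_outputs=torch.ones_like(grad), create_graph=True
        )[0]
    return grad

def heat(u, x, t):
    alpha = 1.0
    return diff(u, [x, x]) - alpha * diff(u, [t])

N = 100
x = torch.linspace(0.0, 1.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))
t = torch.linspace(0.0, 1.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))

# Closed-form solution of u_t = u_xx with u(0, t) = u(1, t) = 0
u = torch.sin(math.pi * x) * torch.exp(-math.pi**2 * t)

res = heat(u, x, t)
print(res.abs().max().item() < 1e-10)  # True
```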

Physics-Informed Neural Networks (PINNs) are a promising approach for solving complex partial differential equations (PDEs), combining neural networks with physics-based constraints. However, apart from data requirements, one of their main limitations is training time. To reduce it, I tried torch.jit.script, torch.jit.trace and torch.compile.

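  • TorchScript (torch.jit.script) version of diff(y, xs):
    from typing import List, Optional
    import torch
    @torch.jit.script
    def D(y: torch.Tensor, xs: List[torch.Tensor]) -> torch.Tensor:
        func: torch.Tensor = y
        ones: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
        for x in xs:
            grad = torch.autograd.grad([func], [x], grad_outputs=ones, create_graph=True, allow_unused=True)[0]
            # allow_unused=True returns None instead of raising when func does not depend on x
            func = grad if grad is not None else torch.zeros_like(x)
        return func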
  • I then tried to trace the heat equation residual as well, but got stuck on an error:
    • JIT Trace:
      import numpy as np
      import torch
      device = "cuda" if torch.cuda.is_available() else "cpu"
      N = 1000
      x = torch.linspace(0.0, 1.0, N, requires_grad=True, device=device).reshape((N,1))
      t = torch.linspace(0.0, 1.0, N, requires_grad=True, device=device).reshape((N,1))
      u = torch.sin(x * np.pi) * torch.exp(-np.pi**2 * t)  # Exact solution of the heat equation for alpha = 1
      def heat(u, x, t):
          alpha = 1.0
          res = D(u, [x, x]) - alpha * D(u, [t])  # Uses the scripted D
          return res
      traced_heat = torch.jit.trace(heat, (u, x, t))
      The error:
       RuntimeError: The following operation failed in the TorchScript interpreter.
       Traceback of TorchScript (most recent call last):
         File "/tmp/ipykernel_19890/2258731968.py", line 6, in D
           ones: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
           for x in xs:
               grad = torch.autograd.grad(
                      ~~~~~~~~~~~~~~~~~~~ <--- HERE
       RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
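To separate the scripted D from the tracing step, D can also be called directly, with no torch.jit.trace around it. A minimal sketch under assumptions: the arguments of the cut-off torch.autograd.grad call ([func], [x], grad_outputs=ones, create_graph=True, allow_unused=True) are inferred, with allow_unused=True suggested by the `is not None` check, and the test input is the closed-form heat solution sin(πx)·exp(-π²t) for alpha = 1.

```python
import math
from typing import List, Optional

import torch

@torch.jit.script
def D(y: torch.Tensor, xs: List[torch.Tensor]) -> torch.Tensor:
    func: torch.Tensor = y
    ones: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
    for x in xs:
        # Assumed arguments for the truncated call; allow_unused=True yields
        # None (instead of an error) when func does not depend on x
        grad = torch.autograd.grad(
            [func], [x], grad_outputs=ones, create_graph=True, allow_unused=True
        )[0]
        func = grad if grad is not None else torch.zeros_like(x)
    return func

N = 100
x = torch.linspace(0.0, 1.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))
t = torch.linspace(0.0, 1.0, N, requires_grad=True, dtype=torch.float64).reshape((N, 1))
u = torch.sin(math.pi * x) * torch.exp(-math.pi**2 * t)

# Residual u_xx - alpha * u_t with alpha = 1, evaluated eagerly (no tracing)
res = D(u, [x, x]) - 1.0 * D(u, [t])
print(res.abs().max().item())
```

This only exercises D on its own; it does not by itself explain the tracing error.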

I have no idea what is causing this. I have no experience with PyTorch, but I wanted to port the code I wrote in TensorFlow during my studies, since PyTorch provides more flexibility and control. With this much flexibility, however, I don't know which option I should be using. Should I try TorchDynamo instead?

  • If I run torch._dynamo.explain on the heat function, I get the following output:
    Dynamo produced 1 graphs with 0 graph break and 2 ops
     Break reasons: 
    TorchDynamo compilation metrics:
    Function, Runtimes (s)
    _compile, 0.0880, 0.0152, 0.0046
    OutputGraph.call_user_compiler, 0.0000
  • Then, if I try to compile the heat function, I get the following error:
    compheat = torch.compile(heat, backend='inductor', fullgraph=True)
    r = compheat(u, x, t)
    Unsupported: inlining disallowed: <function grad at 0x7f36642888b0>
    from user code:
       File "/tmp/ipykernel_19890/543225163.py", line 3, in cheat
        res = D(u, [x, x]) - alpha * D(u, [t])
      File "/tmp/ipykernel_19890/2258731968.py", line 6, in D
        grad = torch.autograd.grad(
    You can suppress this exception and fall back to eager by setting:
        torch._dynamo.config.suppress_errors = True

I am open to suggestions on both the design and the errors above. I am new to PyTorch and have limited experience with the framework, but I am eager to learn it and apply it to problems like solving PDEs with PINNs. I apologize in advance if any of the information I have provided is incorrect or incomplete; I would appreciate any corrections or feedback from the community. Thank you for your understanding.

Please note that as a toy example I used the following network, which takes its inputs as a list and stacks them column-wise before feeding them to the layers:

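import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    def __init__(self, n_input):
        super().__init__()  # must run before submodules are assigned
        # Sigmoid activations between layers; a purely linear stack would make
        # every second derivative of the output identically zero
        self.linear_sigmoid_stack = nn.Sequential(
            nn.Linear(n_input, 32), nn.Sigmoid(),
            nn.Linear(32, 32), nn.Sigmoid(),
            nn.Linear(32, 1),
        )
    def forward(self, x):
        # x is a list of (N, 1) tensors such as [x, t]; hstack gives (N, n_input)
        x = torch.hstack(x)
        return self.linear_sigmoid_stack(x)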
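To show how the pieces combine, here is a minimal sketch that feeds [x, t] through the network, forms the heat residual with diff, and backpropagates a mean-squared residual to the weights. Assumptions: sigmoid activations between the layers (suggested by the attribute name linear_sigmoid_stack), random collocation points in [0, 1), and a loss made of the PDE residual only; a full PINN loss would also include initial and boundary terms.

```python
import torch
import torch.nn as nn

def diff(y, xs):
    # Successive derivatives of y with respect to each tensor in xs
    grad = y
    for x in xs:
        grad = torch.autograd.grad(
            grad, x, grad_outputs=torch.ones_like(grad), create_graph=True
        )[0]
    return grad

class NeuralNet(nn.Module):
    def __init__(self, n_input):
        super().__init__()
        self.linear_sigmoid_stack = nn.Sequential(
            nn.Linear(n_input, 32), nn.Sigmoid(),
            nn.Linear(32, 32), nn.Sigmoid(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        # x: list of (N, 1) tensors -> one (N, n_input) batch
        return self.linear_sigmoid_stack(torch.hstack(x))

def heat(u, x, t):
    alpha = 1.0
    return diff(u, [x, x]) - alpha * diff(u, [t])

torch.manual_seed(0)
N = 64
x = torch.rand(N, 1, requires_grad=True)  # collocation points in [0, 1)
t = torch.rand(N, 1, requires_grad=True)

model = NeuralNet(2)
u = model([x, t])                    # shape (N, 1)
loss = heat(u, x, t).pow(2).mean()   # PDE-residual part of the loss only
loss.backward()                      # differentiates through u_xx and u_t

print(loss.item())
print(model.linear_sigmoid_stack[0].weight.grad.shape)  # torch.Size([32, 2])
```

Because diff uses create_graph=True, loss.backward() can propagate through the second derivative u_xx into the network parameters.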