Second-order derivative for a PINN model

X is an [n, 2] matrix whose columns are x and t. I am using PyTorch to differentiate u(x, t) with respect to X to get du/dt, du/dx, and d²u/dx². Here is my code:

X.requires_grad = True
p = mlp(X)
grads, = torch.autograd.grad(p, X, grad_outputs=p.data.new(p.shape).fill_(1),create_graph=True, only_inputs=True)
grads1, = torch.autograd.grad(grads, X, grad_outputs=grads.data.new(grads.shape).fill_(1), create_graph=True, only_inputs=True)
dpdx, dpdt = grads[:,0], grads[:,1]
dpdxx = grads1[:,0]

I am not getting satisfactory results. I think something is wrong in how I compute the second derivative. Can anyone help me?

Hello there!

I see you're using an MLP. Be careful with the activation functions: ReLU is piecewise linear, so the second derivative of a ReLU network with respect to its inputs is zero almost everywhere, and you will get all zeros. A smooth activation such as tanh or sigmoid should do the trick.
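To make the activation point concrete, here is a minimal sketch that builds the same small MLP twice, once with ReLU and once with tanh, and compares the second derivative with respect to x. The helper name `second_derivative_wrt_x`, the layer sizes, and the random inputs are assumptions made up for this illustration, not part of the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def second_derivative_wrt_x(activation):
    # Hypothetical small MLP: input (x, t), one scalar output per sample.
    mlp = nn.Sequential(
        nn.Linear(2, 16), activation,
        nn.Linear(16, 16), activation,
        nn.Linear(16, 1),
    )
    x = torch.rand(8, 1, requires_grad=True)
    t = torch.rand(8, 1, requires_grad=True)
    u = mlp(torch.cat((x, t), dim=1))
    # First derivative; create_graph=True keeps the graph so it can be differentiated again.
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    # allow_unused=True guards against autograd finding no dependence on x at all.
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), allow_unused=True)[0]
    return torch.zeros_like(x) if u_xx is None else u_xx

relu_uxx = second_derivative_wrt_x(nn.ReLU())
tanh_uxx = second_derivative_wrt_x(nn.Tanh())

print("ReLU max |u_xx|:", relu_uxx.abs().max().item())
print("Tanh max |u_xx|:", tanh_uxx.abs().max().item())
```

With ReLU the second derivative should come out as zero at every sample, while the tanh version should not, which is why a diffusion term like u_xx carries no signal in a ReLU network.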

In addition, if you want to differentiate your neural net with respect to t or x, I'd advise being a bit more careful with how you write the derivatives. Passing x and t as separate tensors lets you take each derivative explicitly.
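As far as I can tell, the likely problem in the snippet from the question is the second `grad` call: `grads` has shape [n, 2], and passing a `grad_outputs` of ones sums both columns before differentiating, so `grads1[:,0]` ends up holding d²p/dx² + d²p/dtdx rather than d²p/dx² alone. The sketch below checks this on a made-up test function p = x³t² (an assumption chosen only because its derivatives are known in closed form), in place of the network.

```python
import torch

# Hypothetical sample points; p = x^3 * t^2 stands in for the network output.
x = torch.tensor([[1.0], [2.0], [0.5]])
t = torch.tensor([[3.0], [1.0], [2.0]])
X = torch.cat((x, t), dim=1).requires_grad_(True)

p = (X[:, 0] ** 3 * X[:, 1] ** 2).unsqueeze(1)  # shape [n, 1]

# First derivatives: column 0 is dp/dx, column 1 is dp/dt.
grads = torch.autograd.grad(p, X, grad_outputs=torch.ones_like(p), create_graph=True)[0]

# Differentiating the whole [n, 2] gradient sums both columns first,
# so column 0 holds p_xx + p_tx instead of p_xx.
mixed = torch.autograd.grad(grads, X, grad_outputs=torch.ones_like(grads), create_graph=True)[0]

# Differentiating only the dp/dx column isolates p_xx in column 0.
dpdx = grads[:, 0:1]
second = torch.autograd.grad(dpdx, X, grad_outputs=torch.ones_like(dpdx), create_graph=True)[0]
dpdxx = second[:, 0]

# Closed-form values for comparison.
exact_pxx = (6 * x * t ** 2).squeeze(1)
exact_pxt = (6 * x ** 2 * t).squeeze(1)

print("isolated p_xx matches exact:", torch.allclose(dpdxx, exact_pxx))
print("mixed column equals p_xx + p_xt:", torch.allclose(mixed[:, 0], exact_pxx + exact_pxt))
```

The same fix applies to the network case: slice out `grads[:, 0:1]` (or keep x and t as separate tensors) before taking the second derivative.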

Here’s a little example:

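import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2, 80)
        self.mi = nn.Linear(80, 80)
        self.mi1 = nn.Linear(80, 80)
        self.mi2 = nn.Linear(80, 40)
        self.ol = nn.Linear(40, 1)
        self.th = nn.Tanh()

    def forward(self, x, t):
        # x and t are separate [n, 1] tensors, so each can be differentiated on its own.
        u = torch.cat((x, t), 1)
        u = self.th(self.l1(u))
        u = self.th(self.mi(u))
        u = self.th(self.mi1(u))
        u = self.th(self.mi2(u))
        u = self.th(self.ol(u))
        return u

net = Net()

def f(x, t):
    # Residual of Burgers' equation; x and t must have requires_grad=True.
    u = net(x, t)
    # create_graph=True keeps the graph so the derivatives can be differentiated again.
    u_t = grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_x = grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    # Second derivative: differentiate u_x (not u) with respect to x.
    u_xx = grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    w = 0.01 / np.pi  # viscosity coefficient
    return u_t + u * u_x - w * u_xx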

This example is for Burgers' equation. I got the original code from burgernet.py in the dbgannon/NNets-and-Diffeqns repository on GitHub (NNets-and-Diffeqns/burgernet.py at master).
It's explained here: "Notes on Deep Learning and Differential Equations" on the Cloud Computing For Science and Engineering blog.
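In case it's useful, here is a rough sketch of how a residual like this can feed into a training step. Everything in it is an assumption for illustration: the small stand-in network `net`, the function name `residual`, the collocation domain (x in [-1, 1], t in [0, 1]), and the Adam settings. A real PINN loss would also need initial and boundary condition terms, which are left out here.

```python
import math
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)

# Hypothetical small stand-in for the network; any smooth-activation MLP would do.
net = nn.Sequential(
    nn.Linear(2, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 1),
)

def residual(x, t):
    # Burgers' residual u_t + u * u_x - (0.01 / pi) * u_xx at the given points.
    u = net(torch.cat((x, t), dim=1))
    u_t = grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_x = grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_xx = grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    return u_t + u * u_x - (0.01 / math.pi) * u_xx

# Random collocation points; requires_grad=True so derivatives w.r.t. x and t exist.
x = (2 * torch.rand(64, 1) - 1).requires_grad_(True)
t = torch.rand(64, 1).requires_grad_(True)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# One optimisation step on the mean squared PDE residual only.
optimizer.zero_grad()
loss = residual(x, t).pow(2).mean()
loss.backward()
optimizer.step()

print("PDE residual loss:", loss.item())
```

Because every `grad` call uses `create_graph=True`, `loss.backward()` can propagate through the derivative terms back to the network weights.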

Hope it helps!