Calculating Second Derivative

Hi All,

I am relatively new to PyTorch, and I am trying to compute a second derivative. However, it always comes out as zero for some reason. Below is the code:

import torch
from torch.autograd import grad
dev = torch.device('cpu')
if torch.cuda.is_available():
    dev = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
def Sec_Der(y,x):
        
    duxdxyz = grad(y[:, 0].unsqueeze(1), x, torch.ones(x.size()[0], 1, device=dev), create_graph=True, retain_graph=True)[0]
    duydxyz = grad(y[:, 1].unsqueeze(1), x, torch.ones(x.size()[0], 1, device=dev), create_graph=True, retain_graph=True)[0]
    duzdxyz = grad(y[:, 2].unsqueeze(1), x, torch.ones(x.size()[0], 1, device=dev), create_graph=True, retain_graph=True)[0]
        
    sec_der = grad(duzdxyz[:,0].unsqueeze(1), x, torch.ones(x.size()[0], 1), create_graph=True, retain_graph=True)[0]
    print(sec_der)
            
    return sec_der
x = 2*torch.rand((5,3))
x.requires_grad_(True)
x.retain_grad()
print(x)
y = x**3 + 2*x
print(y)
sec_der = Sec_Der(y,x)

I think you tricked yourself with the indexing.
Since y[:, 2] only depends on the last column of x, only the last column of that gradient is nonzero (try print(duzdxyz)). You then take sec_der as the grad of duzdxyz[:, 0], an all-zero column, so the result is zero.
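For instance, here is a minimal CPU-only sketch (reusing your variable names, not your exact setup) where differentiating the matching column recovers the expected 6 * x in the last column:

import torch
from torch.autograd import grad

x = 2 * torch.rand((5, 3), requires_grad=True)
y = x**3 + 2*x  # elementwise, so y[:, 2] depends only on x[:, 2]

# first derivative of y[:, 2] w.r.t. x: only the last column is nonzero (3*x**2 + 2)
duzdxyz = grad(y[:, 2].unsqueeze(1), x, torch.ones(x.size(0), 1), create_graph=True)[0]

# differentiate the last (nonzero) column instead of column 0
sec_der = grad(duzdxyz[:, 2].unsqueeze(1), x, torch.ones(x.size(0), 1))[0]
print(sec_der[:, 2])  # matches 6 * x[:, 2]
print(6 * x[:, 2])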

Best regards

Thomas

Many thanks, @tom, for the reply.

I see. Yes, it seems I tricked myself with the indexing. But even when I define y differently, I still get zeros. Let’s say that

m = torch.nn.Linear(3, 3)
y = m(torch.relu(m(torch.relu(m(x)))))

Now print(duzdxyz) is no longer all zeros, yet sec_der still comes out as zero.

Now you tricked yourself with too much linearity.
While relu is nonlinear globally (else we would famously not have the universal approximation property), it is linear in a neighborhood of almost every input (if you pardon the mathematics lingo), rendering your network linear in a neighborhood of your inputs.
Taking a more global view, the derivative is piecewise constant, and the second derivative is 0 wherever it is defined.
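You can see this numerically with a small stand-in network (just a sketch, a Sequential with one ReLU, not your exact model): the first grad varies between activation regions, but the double backward comes back as exact zeros.

import torch
from torch.autograd import grad

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))

x = torch.rand(5, 3, requires_grad=True)
y = net(x)

# first derivative: piecewise constant in x (constant within each ReLU region)
dydx = grad(y.sum(), x, create_graph=True)[0]
print(dydx)

# second derivative: identically zero wherever it is defined
d2ydx2 = grad(dydx.sum(), x)[0]
print(d2ydx2)  # all zeros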

Thanks for the explanation @tom. It makes sense. In this case, what type of layers can I use to capture the second derivative of the data? Maybe I keep the linear layers but use other activations? Any recommendations?

At the risk of being a math nerd: Don’t say “the (second) derivative of the data”. It is the “derivative of the output with respect to the input” (well, actually you typically compute Hessian-vector products when using second derivatives with backprop, but hey).
Many (most?) of the loss functions are sufficiently nonlinear, but you would, of course, also get nonzero second derivatives of the function specified by your neural network if you used an activation function that is curved (e.g. tanh where it is not saturated).
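As a sketch (with a hypothetical small tanh network, not any particular model from above), you can get such a Hessian-vector product by calling grad twice:

import torch
from torch.autograd import grad

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))

x = torch.rand(5, 3, requires_grad=True)
y = net(x)

# first derivative of the summed output w.r.t. x
dydx = grad(y.sum(), x, create_graph=True)[0]

# Hessian-vector product: differentiate (dydx . v) w.r.t. x
v = torch.ones_like(x)
hvp = grad((dydx * v).sum(), x)[0]
print(hvp)  # nonzero, because tanh is curved away from saturation

torch.autograd.functional.hvp wraps this double-backward pattern if you prefer a one-call API.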

🙂 Thanks @tom