Autograd Confusion

Hi all, I’m very new here and I’m currently implementing a PDE solver in PyTorch. I am trying to implement the following PDE (the Black-Scholes equation):
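
$$f = \frac{\partial V}{\partial t} + \frac{1}{2}\sigma^2 S^2 \frac{\partial^2 V}{\partial S^2} + rS\frac{\partial V}{\partial S} - rV = 0$$
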
and I will use its left-hand side, f, as the target function to measure my loss against (here the loss is mean(f^2), i.e. the mean of f² over the sample points).

I have written this as below:

    # Loss term #1: PDE residual
    V = model(sample1)
    # first derivatives w.r.t. both inputs: column 0 is dV/dt, column 1 is dV/dS
    grad_V = torch.autograd.grad(V.sum(), sample1, create_graph=True, retain_graph=True)[0]
    V_t = grad_V[:,0]
    V_s = grad_V[:,1]
    # second derivative: differentiate dV/dS w.r.t. the full input, keep the S column
    V_ss = torch.autograd.grad(V_s.sum(), sample1, create_graph=True, retain_graph=True)[0][:,1]

    # Black-Scholes residual f
    f = V_t + 0.5 * sigma**2 * sample1[:,1]**2 * V_ss + r * sample1[:,1] * V_s - r*V

    #print(V_t, V_s, V_ss)
    L1 = torch.mean(torch.pow(f, 2))
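One way to check whether these derivatives are right is to swap the network for a function whose derivatives are known in closed form. The sketch below uses a hypothetical stand-in `toy_model` with V(t, S) = e^t S³, so V_t = e^t S³, V_S = 3e^t S² and V_SS = 6e^t S, and reuses the same two `torch.autograd.grad` calls; the values of `sigma` and `r` are placeholders. It also illustrates a shape pitfall worth ruling out (an assumption, since `model` isn't shown): if the model returns shape (N, 1), as a final `nn.Linear(..., 1)` layer does, then the (N,) derivative terms minus `r*V` broadcast to an N×N matrix instead of N residuals.

```python
import torch

torch.manual_seed(0)
N = 5
sigma, r = 0.2, 0.05  # hypothetical parameter values

# hypothetical stand-in for the network: V(t, S) = exp(t) * S**3, returned with shape (N, 1)
def toy_model(x):
    t, S = x[:, 0], x[:, 1]
    return (torch.exp(t) * S**3).unsqueeze(1)

sample1 = torch.rand(N, 2, dtype=torch.float64, requires_grad=True)
t, S = sample1[:, 0].detach(), sample1[:, 1].detach()  # plain copies for the closed-form check

V = toy_model(sample1)  # shape (N, 1)
grad_V = torch.autograd.grad(V.sum(), sample1, create_graph=True)[0]
V_t, V_s = grad_V[:, 0], grad_V[:, 1]
V_ss = torch.autograd.grad(V_s.sum(), sample1, create_graph=True)[0][:, 1]

# compare with the closed-form derivatives
print(torch.allclose(V_t.detach(), torch.exp(t) * S**3))      # True
print(torch.allclose(V_s.detach(), 3 * torch.exp(t) * S**2))  # True
print(torch.allclose(V_ss.detach(), 6 * torch.exp(t) * S))    # True

# shape pitfall: (N,) terms combined with the (N, 1) tensor V broadcast to (N, N)
f_bad = V_t + 0.5 * sigma**2 * sample1[:, 1]**2 * V_ss + r * sample1[:, 1] * V_s - r * V
print(f_bad.shape)  # torch.Size([5, 5])

# flattening V to shape (N,) gives one residual per sample point
f = V_t + 0.5 * sigma**2 * sample1[:, 1]**2 * V_ss + r * sample1[:, 1] * V_s - r * V.squeeze(1)
print(f.shape)  # torch.Size([5])
```

If the derivative checks pass but `f` comes out N×N, the mean of f² mixes every sample's derivatives with every other sample's value of V, which would explain a poor fit even with correct autograd calls.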

where “sample1” is an N×2 tensor in which sample1[:,0] holds randomly sampled times (t) and sample1[:,1] holds randomly sampled values of S.
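For reference, a minimal sketch of building such a tensor so that autograd can differentiate with respect to it. The bounds `T` and `S_max` are hypothetical placeholders, not values from this post. The essential part is calling `requires_grad_(True)` on the N×2 input: without it, `torch.autograd.grad(V.sum(), sample1, ...)` raises a RuntimeError because `sample1` is not part of the graph.

```python
import torch

N = 1000
T, S_max = 1.0, 200.0  # hypothetical time horizon and upper bound for S

t = T * torch.rand(N)      # random times in [0, T)
S = S_max * torch.rand(N)  # random asset prices in [0, S_max)

# column 0 holds t, column 1 holds S; the stacked tensor is the leaf we differentiate against
sample1 = torch.stack([t, S], dim=1).requires_grad_(True)
print(sample1.shape)  # torch.Size([1000, 2])
```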

My question is this: I suspect the code doesn’t compute the derivatives I think it does, because my answer after training is poor. I am also unsure which graphs should be retained and which should be created. Can anyone help me work out how to get this snippet to compute the terms in the equation above?
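On which graphs to create and retain: in `torch.autograd.grad`, `create_graph=True` builds a graph for the returned gradient so it can be differentiated again, and `retain_graph` defaults to the value of `create_graph`, so passing `retain_graph=True` alongside it should be redundant. Both calls plausibly need `create_graph=True` here: the first so that `V_s` can be differentiated to get `V_ss`, and the second so that a loss built from `V_ss` can still be backpropagated into the model's weights. The sketch below uses a hypothetical small `nn.Sequential` network in place of `model` and shows what happens when the first call omits `create_graph`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# hypothetical small network standing in for `model`
model = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
sample1 = torch.rand(8, 2, requires_grad=True)

# Without create_graph, the returned gradient is not attached to any graph,
# so trying to differentiate it a second time raises a RuntimeError.
V = model(sample1)
g = torch.autograd.grad(V.sum(), sample1)[0]
print(g.requires_grad)  # False
try:
    torch.autograd.grad(g[:, 1].sum(), sample1)
    second_order_failed = False
except RuntimeError:
    second_order_failed = True

# With create_graph=True on both calls, V_ss stays differentiable w.r.t. the weights.
V = model(sample1)
g = torch.autograd.grad(V.sum(), sample1, create_graph=True)[0]
V_ss = torch.autograd.grad(g[:, 1].sum(), sample1, create_graph=True)[0][:, 1]
loss = V_ss.pow(2).mean()
loss.backward()  # gradients reach the parameters through the second derivative
print(model[0].weight.grad is not None)  # True
```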

Many thanks in advance

EDIT: it is also worth noting that if I try to replace the code above with

    # Loss term #1: PDE
    V = model(sample1)
    grad_V = torch.autograd.grad(V.sum(), sample1, create_graph=True, retain_graph=True)[0]
    V_t = grad_V[:,0]
    V_s = grad_V[:,1]
    # attempt: differentiate dV/dS w.r.t. the S column only
    V_ss = torch.autograd.grad(V_s.sum(), sample1[:,1], create_graph=True, retain_graph=True)[0]

    f = V_t + 0.5 * sigma**2 * sample1[:,1]**2 * V_ss + r * sample1[:,1] * V_s - r*V

    #print(V_t, V_s, V_ss)
    L1 = torch.mean(torch.pow(f, 2))

which I would have thought was equivalent, I receive the error “One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.” Setting allow_unused=True results in the gradient being equal to 0.
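A likely explanation for this error: the expression `sample1[:,1]` creates a new tensor each time it is evaluated, and that new tensor was never an input to `model`, so autograd treats it as unused. With `allow_unused=True`, `torch.autograd.grad` returns `None` for such an input by default. One option is to differentiate with respect to the full `sample1` and index the result afterwards, as in the first snippet. Another, sketched below, is to keep t and S as separate leaf tensors and assemble the network input from them, so each can be differentiated against directly. `toy_model` is a hypothetical stand-in with V = t·S², so V_SS = 2t is known.

```python
import torch

N = 4
# hypothetical stand-in for `model`: V(t, S) = t * S**2, returned with shape (N, 1)
def toy_model(x):
    return (x[:, 0] * x[:, 1] ** 2).unsqueeze(1)

sample1 = torch.rand(N, 2, requires_grad=True)
V = toy_model(sample1)

# sample1[:, 1] here is a fresh tensor created by this indexing call; V was never computed from it
out = torch.autograd.grad(V.sum(), sample1[:, 1], allow_unused=True)[0]
print(out)  # None

# Alternative: keep t and S as separate leaves and build the input from them.
t = torch.rand(N, requires_grad=True)
S = torch.rand(N, requires_grad=True)
V = toy_model(torch.stack([t, S], dim=1)).squeeze(1)
V_s = torch.autograd.grad(V.sum(), S, create_graph=True)[0]
V_ss = torch.autograd.grad(V_s.sum(), S, create_graph=True)[0]
print(torch.allclose(V_ss.detach(), 2 * t.detach()))  # True
```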