Hi, I was able to implement a custom autograd Function for the Bessel function Jv that supports arbitrary-order derivatives, as shown below:
import numpy as np
from scipy.special import jv
import torch

class besselJv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, v):
        ctx.save_for_backward(x, v)
        return jv(v, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, v = ctx.saved_tensors
        # dJ_v/dx = 0.5*(J_{v-1}(x) - J_{v+1}(x)); calling apply() here keeps the
        # backward pass differentiable, so create_graph=True gives higher derivatives
        return grad_out*0.5*(besselJv.apply(x, v-1) - besselJv.apply(x, v+1)), None

x = torch.tensor([2.0], dtype=torch.double, requires_grad=True)
v = torch.tensor(1)
y = besselJv.apply(x, v)
dy_dx = torch.autograd.grad(y, x, create_graph=True)
dy2_dx2 = torch.autograd.grad(dy_dx, x)
print(dy_dx[0].item())
print(dy2_dx2[0])
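For reference, the 0.5*(J_{v-1} - J_{v+1}) factor in backward comes from the standard Bessel derivative identity dJ_v/dx = (J_{v-1}(x) - J_{v+1}(x)) / 2. Here is a small numerical check of that identity against SciPy's own derivative helper jvp (just a sanity check I ran, not part of the implementation):

# quick numerical check of the identity used in backward:
#   dJ_v/dx = (J_{v-1}(x) - J_{v+1}(x)) / 2
from scipy.special import jv, jvp

x0, order = 2.0, 1.0
identity_value = 0.5 * (jv(order - 1, x0) - jv(order + 1, x0))
print(identity_value)      # value from the recurrence used in backward
print(jvp(order, x0))      # SciPy's derivative of J_v, should agree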
However, when I try implementing this with v as a constant rather than a tensor, following this post:
import numpy as np
from scipy.special import jv
import torch

class besselJv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, v):
        ctx.save_for_backward(x)
        ctx._v = v
        return torch.from_numpy(jv(v, x.detach().numpy()))

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.saved_tensors
        v = ctx._v
        return grad_out*0.5*(besselJv.apply(x, v-1) - besselJv.apply(x, v+1)), None

x = torch.tensor([2.0], dtype=torch.double, requires_grad=True)
v = torch.tensor(1)
y = besselJv.apply(x, v)
dy_dx = torch.autograd.grad(y, x, create_graph=True)
dy2_dx2 = torch.autograd.grad(dy_dx, x)
print(dy_dx[0].item())
print(dy2_dx2[0])
I get an error with this line:
What is the issue here? Is this code outdated?
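In case it helps narrow things down, here is the variant I would have expected to work, based only on my reading of the extending-autograd docs. The tuple unpacking of ctx.saved_tensors and passing v as a plain Python number are my own guesses about what might be different, not anything taken from that post:

import numpy as np
from scipy.special import jv
import torch

class besselJvConst(torch.autograd.Function):    # name is just mine, to keep the two versions apart
    @staticmethod
    def forward(ctx, x, v):
        ctx.save_for_backward(x)
        ctx._v = v                                # v is a plain Python number here
        return torch.from_numpy(jv(v, x.detach().numpy()))

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors                    # saved_tensors is a tuple, so unpack the single tensor
        v = ctx._v
        return grad_out*0.5*(besselJvConst.apply(x, v-1) - besselJvConst.apply(x, v+1)), None

x = torch.tensor([2.0], dtype=torch.double, requires_grad=True)
y = besselJvConst.apply(x, 1.0)                   # pass the order as a float, not a tensor
dy_dx = torch.autograd.grad(y, x, create_graph=True)
dy2_dx2 = torch.autograd.grad(dy_dx, x)
print(dy_dx[0].item())
print(dy2_dx2[0])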
Thank you,
Alex