Hi, I’m getting an error due to passing an argument (I think its the float k*r_trans) when using extra arguments to the function jacrev()
. I tried declaring this as a tensor and then converting it back to a float inside the function pressureSpatialCompRealFunc
but then it says the new variable doesnt have storage. Is there a way of passing this variable? The code is below.
Also, when using jacfwd()
on the output of jacrev()
, do I need the extra parameters?
Thank you,
Alex
import numpy as np
from scipy.special import jv
import torch
import time
#Autograd class declaration of function J_v(sqrt(x)) for pytorch
#This implementation does not have the machine precision error due to /sqrt(x) when x is or is close to 0
class modJv(torch.autograd.Function):
@staticmethod
def forward(ctx, x, v, a):
ctx.save_for_backward(x)
ctx._v = v
ctx._a = a
return torch.from_numpy(jv(v, np.array(a*np.sqrt(x.detach().numpy()))))
@staticmethod
def backward(ctx, grad_out):
x, = ctx.saved_tensors
v = ctx._v
a = ctx._a
Jv = modJv.apply(x, v,a)
Jvm2 = modJv.apply(x, v-2,a)
Jv2 = modJv.apply(x, v+2,a)
return grad_out*a*a/8.0*((Jv+Jvm2)/(v-1) - (Jv2+Jv)/(v+1)), None, None
#Function to compute parts of sound field that are spatialy dependent and don't have complex exponents
def pressureSpatialCompFunc(point, transLocation,transNormal, kr_trans):
dirV = 1 #Nu parameter for Bessel function in original form of transducer field directivity function
distVec = point-transLocation
dist2 = distVec.pow(2).sum().unsqueeze(dim=0)
sinTheta2 = torch.cross(distVec,transNormal).pow(2).sum().unsqueeze(dim=0)/dist2
directivity = 0.25*(modJv.apply(sinTheta2,dirV+1,kr_trans) + modJv.apply(sinTheta2,dirV-1,kr_trans))
dist = dist2.sqrt()
return directivity/dist, dist
#All values in SI units
c_m = 332.0 #Speed of sound in air (m/s)
f_trans = 40.0e3 #Transducer frequency (40 kHz)
k = 2*np.pi*f_trans/c_m #Sound wavenumber
r_trans = 5.0e-3 #Transducer radius (5 mm)
#Transducer location
transLocation = torch.tensor([4.0e-3, 2.0e-3, 1.0e-3], dtype=torch.double, requires_grad=False)
#Transducer unit normal vector
transNormal = torch.tensor([0.0, 0.0, 1.0 ], dtype=torch.double, requires_grad=False)
#Point where sound field is measured
point = torch.tensor([4.0e-3, 2.0e-3, 20.0e-3], dtype=torch.double, requires_grad=True)
#Compute real and imaginary parts of spatialy depedent component of transducer sound field
def pressureSpatialCompRealFunc(point, transLocation,transNormal, kr_trans):
p_spatial_comp, tpDist = pressureSpatialCompFunc(point, transLocation,transNormal, kr_trans)
p_spatial_comp_real = p_spatial_comp*torch.sin(k*tpDist)
start_time = time.time()
#p_spatial_comp, tpDist = pressureSpatialCompFunc(point, transLocation,transNormal, kr_trans)
#p_spatial_comp_real = p_spatial_comp*torch.sin(k*tpDist)
#jac_f = torch.autograd.grad(p_spatial_comp_real, point, create_graph=True)
jac_f = torch.func.jacrev(pressureSpatialCompRealFunc,argnums=3)(point, transLocation,transNormal,k*r_trans)
jac_time = time.time() - start_time
print("--- %s seconds ---" % jac_time)
start_time = time.time()
hess_f = torch.func.jacfwd(jac_f,argnums=0)(point)
hess_time = time.time() - start_time
print("--- %s seconds ---" % hess_time)
print("--- %s seconds ---" % (jac_time + hess_time))
print(jac_f)
print(hess_f)