The following code works, but is slow. How can I speed this up? Thanks!
import torch
import time
class Kernel:
    """Base class that evaluates mixed partial derivatives of a kernel via autograd.

    Subclasses supply ``k(x, z)``. It is called with one row of ``x`` and one
    row of ``z`` and its result is passed to ``torch.autograd.grad`` without
    ``grad_outputs``, so it has to return a scalar tensor.
    """

    def __call__(self, x, z, alpha, beta):
        """Return the matrix of kernel derivatives for every pair of rows.

        Entry ``[i, j]`` is ``k(x[i], z[j])`` differentiated ``alpha[i][c]``
        times with respect to coordinate ``c`` of ``x[i]`` and ``beta[j][c]``
        times with respect to coordinate ``c`` of ``z[j]``.

        Args:
            x: Points; promoted to shape ``(n, d)`` with ``torch.atleast_2d``.
            z: Points; promoted to shape ``(m, d)``.
            alpha: Integer derivative orders for ``x``, shape ``(n, d)``.
            beta: Integer derivative orders for ``z``, shape ``(m, d)``.

        Returns:
            A freshly built ``(n, m)`` tensor of values (not attached to the
            autograd graph). Empty when ``n`` or ``m`` is zero.

        Raises:
            AssertionError: If the four shapes are not mutually consistent.
        """
        x, z, alpha, beta = (torch.atleast_2d(t) for t in (x, z, alpha, beta))
        # Read sizes from .shape so an empty x (n == 0) does not fail on x[0].
        n, m, d = x.shape[0], z.shape[0], x.shape[1]
        assert (
            x.shape == (n, d)
            and z.shape == (m, d)
            and alpha.shape == (n, d)
            and beta.shape == (m, d)
        )
        if n == 0 or m == 0:
            return torch.empty((n, m))
        # Convert the orders to plain Python ints once instead of indexing the
        # tensors again inside the innermost loops.
        alpha_orders = alpha.tolist()
        beta_orders = beta.tolist()
        # Unbind once per argument. Even when x and z are the same tensor the
        # two calls give distinct row tensors, so derivatives with respect to
        # x[i] and z[j] stay independent.
        x_rows = x.unbind(0)
        z_rows = z.unbind(0)
        # torch.tensor copies the values, so each entry is detached right away;
        # otherwise every entry keeps its graph alive until the matrix is built.
        return torch.tensor([
            [
                self._kd(x_row, z_row, a_row, b_row, d).detach()
                for z_row, b_row in zip(z_rows, beta_orders)
            ]
            for x_row, a_row in zip(x_rows, alpha_orders)
        ])

    def _kd(self, x, z, alpha, beta, d):
        """Differentiate ``k(x, z)`` for a single pair of points.

        Args:
            x: One point with ``d`` coordinates.
            z: One point with ``d`` coordinates.
            alpha: Per-coordinate derivative orders for ``x``; each item must
                be usable as the argument of ``range``.
            beta: Per-coordinate derivative orders for ``z``; same requirement.
            d: Number of coordinates.

        Returns:
            The scalar result of ``self.k`` after all requested derivatives.
        """
        y = self.k(x, z)
        for i in range(d):
            for _ in range(alpha[i]):
                # [0] is the gradient for the single input, [i] its i-th
                # component; create_graph=True keeps it differentiable.
                y = torch.autograd.grad(y, x, create_graph=True)[0][i]
            for _ in range(beta[i]):
                y = torch.autograd.grad(y, z, create_graph=True)[0][i]
        return y
class GaussianKernel(Kernel):
    """Kernel ``exp(-sum((x - z)**2) / l)`` with scale parameter ``l``."""

    def __init__(self, l=1.):
        # l divides the summed squared difference inside the exponential.
        self.l = l

    def k(self, x, z):
        """Evaluate the kernel; the squared difference is summed over all elements."""
        diff = x - z
        sq_dist = torch.sum(diff ** 2)
        return torch.exp(-sq_dist / self.l)
# Benchmark: time the kernel-matrix evaluation without and with derivatives.
# time.perf_counter() is used because it is the monotonic, high-resolution
# clock intended for measuring short durations.
kernel = GaussianKernel()
nx = 100

# no derivatives, takes about .5 sec
x = torch.rand(nx, 1, requires_grad=True)
alpha = torch.zeros(nx, 1, dtype=torch.int)
t0 = time.perf_counter()
kmat = kernel(x, x, alpha, alpha)
print("time %.1e" % (time.perf_counter() - t0))

# with derivatives, takes about 10 sec
# nx must split evenly into four groups, one per derivative order 0..3.
assert nx % 4 == 0
n4 = nx // 4
alpha_deriv = torch.tensor(
    [0] * n4 + [1] * n4 + [2] * n4 + [3] * n4, dtype=torch.int
)[:, None]
t0 = time.perf_counter()
kmat = kernel(x, x, alpha_deriv, alpha_deriv)
print("time %.1e" % (time.perf_counter() - t0))