When I use the loggamma and digamma functions and then backpropagate through a gradient (a second-order derivative), I get `RuntimeError: polygamma(n,x) is not implemented for n>=2, but was 2`. Here is sample code that reproduces the error:
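```python
import torch

init_var = torch.randn([1, 2], requires_grad=True)
x = torch.randn([1, 2])
y = init_var * x
gt = torch.randn([1, 2])

# digamma(y); its derivative is polygamma(1, y)
log_gamma_res = torch.digamma(y)
loss = (gt - log_gamma_res).sum()

# First-order gradient; create_graph=True keeps it differentiable
grad = torch.autograd.grad(loss, init_var, create_graph=True)[0]

# Backpropagating through `grad` needs the derivative of polygamma(1, y),
# i.e. polygamma(2, y); this is the call that raises the RuntimeError
loss2 = (y - grad).sum()
loss2.backward()
```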
It would be great if someone could explain this error and suggest how I can fix it.
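For reference, the derivative chain involved is digamma'(x) = polygamma(1, x) and digamma''(x) = polygamma(2, x), so it is the second backward pass (through `grad`, built with `create_graph=True`) that needs `polygamma(2, ·)`. Below is a minimal sketch for checking whether a given PyTorch install implements that order. The helper names `supports_polygamma_order` and `second_order_digamma_works` are hypothetical; they simply call `torch.polygamma` / run a tiny double backward and catch the `RuntimeError` shown above.

```python
import torch


def supports_polygamma_order(n: int) -> bool:
    """Hypothetical helper: True if torch.polygamma(n, x) runs on this install."""
    x = torch.tensor([0.5, 1.5, 3.0])
    try:
        torch.polygamma(n, x)
        return True
    except RuntimeError:
        # Builds without support raise "polygamma(n,x) is not implemented for n>=2"
        return False


def second_order_digamma_works() -> bool:
    """Hypothetical helper: minimal double backward through torch.digamma."""
    w = torch.tensor([1.5, 2.5], requires_grad=True)
    out = torch.digamma(w).sum()
    # First derivative: polygamma(1, w); create_graph=True makes it differentiable
    (g,) = torch.autograd.grad(out, w, create_graph=True)
    try:
        # Second derivative: requires polygamma(2, w)
        g.sum().backward()
        return True
    except RuntimeError:
        return False


if __name__ == "__main__":
    print("torch", torch.__version__)
    print("polygamma(2, x) supported:", supports_polygamma_order(2))
    print("digamma double backward works:", second_order_digamma_works())
```

Which of these returns `False` depends on the installed PyTorch version, so printing `torch.__version__` alongside the result is useful when reporting the issue.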