When I try to use the log-gamma and digamma functions, I get the error "RuntimeError: polygamma(n,x) is not implemented for n>=2, but was 2". Here is sample code that reproduces the error:

import torch

init_var = torch.randn([1,2], requires_grad = True)

x = torch.randn([1,2])

y = (init_var*x)

gt = torch.randn([1,2])

log_gamma_res = torch.digamma(y) #y

loss = (gt - log_gamma_res).sum()

grad = torch.autograd.grad(loss,init_var, create_graph = True)[0]

loss2 = (y - grad).sum()

loss2.backward()
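Here is a sketch of where the n = 2 in the message most likely comes from. It is worked by hand from the repro above. Write w for init_var, t for gt, and ψ^(n) for torch.polygamma(n, ·), so ψ^(1) is the derivative of digamma ψ and all sums run over the elements:

```latex
\begin{aligned}
y_i &= w_i x_i, \qquad \mathcal{L} = \sum_i \bigl(t_i - \psi(y_i)\bigr),\\
g_i = \frac{\partial \mathcal{L}}{\partial w_i} &= -\psi^{(1)}(y_i)\,x_i,\\
\mathcal{L}_2 = \sum_i \bigl(y_i - g_i\bigr) &= \sum_i \Bigl(w_i x_i + \psi^{(1)}(w_i x_i)\,x_i\Bigr),\\
\frac{\partial \mathcal{L}_2}{\partial w_i} &= x_i + \psi^{(2)}(y_i)\,x_i^2.
\end{aligned}
```

Because create_graph=True keeps the graph of grad, loss2.backward() has to differentiate polygamma(1, y). That requires polygamma(2, y), which appears to be the case the error says is not implemented in the PyTorch version being used.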

It would be great if someone could explain this error and suggest how I can fix it.