Calculating the Hessian trace for intermediate inputs

This library provides a method to calculate the Hessian trace of the parameters using Hutchinson's method. In this method, you obtain the parameters and their gradients, then differentiate the gradients again with torch.autograd.grad(gradsH, params). The code is below:

def trace(self, maxIter=100, tol=1e-3):
    """
    Compute the trace of the Hessian using Hutchinson's method.
    maxIter: maximum number of iterations used to compute the trace
    tol: relative tolerance for early stopping
    """
    device = self.device
    trace_vhv = []
    trace = 0.

    for i in range(maxIter):
        self.model.zero_grad()
        # generate Rademacher random vectors (entries are +1 or -1)
        v = [
            torch.randint_like(p, high=2, device=device)
            for p in self.params
        ]
        for v_i in v:
            v_i[v_i == 0] = -1

        if self.full_dataset:
            _, Hv = self.dataloader_hv_product(v)
        else:
            Hv = hessian_vector_product(self.gradsH, self.params, v)
        # group_product sums the per-tensor inner products <Hv, v>
        trace_vhv.append(group_product(Hv, v).cpu().item())
        # stop once the running mean of the estimates has stabilized
        if abs(np.mean(trace_vhv) - trace) / (abs(trace) + 1e-6) < tol:
            return trace_vhv
        else:
            trace = np.mean(trace_vhv)

    return trace_vhv
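For intuition, here is a self-contained sketch of the same estimator on a simple quadratic, outside the library (all names here are illustrative, not part of the library's API). Since the Hessian of f(x) = sum(a * x**2) is diag(2a) and Rademacher vectors satisfy v**2 == 1, every sample v·Hv equals the exact trace 2 * a.sum():

```python
import torch

# Hutchinson sketch on f(x) = sum(a * x**2); Hessian is diag(2a),
# so the exact trace is 2 * a.sum() = 12.0.
torch.manual_seed(0)

a = torch.tensor([1.0, 2.0, 3.0])
x = torch.randn(3, requires_grad=True)
loss = (a * x ** 2).sum()

# first backward pass with create_graph=True so gradsH can be
# differentiated again
gradsH = torch.autograd.grad(loss, x, create_graph=True)[0]

estimates = []
for _ in range(10):
    # Rademacher random vector (+1/-1 entries)
    v = torch.randint_like(x, high=2)
    v[v == 0] = -1
    # Hessian-vector product via a second backward pass
    Hv = torch.autograd.grad(gradsH, x, grad_outputs=v, retain_graph=True)[0]
    estimates.append(torch.dot(Hv, v).item())

trace_estimate = sum(estimates) / len(estimates)
print(trace_estimate)  # ≈ 12.0
```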



def hessian_vector_product(gradsH, params, v):
    """
    Compute the Hessian-vector product Hv, where
    gradsH is the gradient at the current point,
    params are the corresponding variables,
    and v is the vector.
    """
    hv = torch.autograd.grad(gradsH,
                             params,
                             grad_outputs=v,
                             only_inputs=True,
                             retain_graph=True)
    return hv
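For hessian_vector_product to work, gradsH must itself be part of the autograd graph, i.e. it must have been produced with create_graph=True. A minimal sketch of that setup (the model and data names are illustrative, not the library's):

```python
import torch
import torch.nn as nn

# Sketch: producing gradsH so that torch.autograd.grad(gradsH, params, ...)
# is possible afterwards.
model = nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

loss = nn.functional.mse_loss(model(inputs), targets)
params = [p for p in model.parameters() if p.requires_grad]

# create_graph=True records the graph of the gradient computation,
# so each element of gradsH carries a grad_fn
gradsH = torch.autograd.grad(loss, params, create_graph=True)
print(all(g.grad_fn is not None for g in gradsH))  # True
```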

But when calculating the Hessian trace of intermediate inputs, I can capture the intermediate inputs and their gradients input_grads with hooks, yet I cannot call torch.autograd.grad(input_grads, inputs), because input_grads have no grad_fn. Is there any way to use torch.autograd.grad for intermediate inputs?
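A minimal reproduction of the issue (module and variable names are illustrative): gradients delivered to a backward hook during a plain backward() call are detached tensors, so they cannot be differentiated a second time:

```python
import torch
import torch.nn as nn

# Gradients captured by a backward hook have no grad_fn when backward()
# is called without create_graph=True, so a second autograd.grad fails.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
captured = {}

def bwd_hook(module, grad_input, grad_output):
    captured["grad_in"] = grad_input[0]

model[1].register_full_backward_hook(bwd_hook)

x = torch.randn(2, 4)
loss = model(x).sum()
loss.backward()
print(captured["grad_in"].grad_fn)  # None
```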