RuntimeError: derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented

Hi Guys,

I am trying to compute the Hessian matrix of a GPT model. However, when I take the gradients I run into this error: RuntimeError: derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented.

How can I fix this error?

I added this

y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C) 

in the attention network, but I still get the same error.
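For reference, that call sits inside my attention module's forward roughly like this (trimmed down, nanoGPT-style; the projection layers are only sketched here and the exact details probably don't matter for the question):

import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # joint q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)      # output projection
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # this is the line the second backward seems to choke on
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)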

Below is the code that raises the error.


import torch
from torch.autograd import grad

def hessian_vector_product(vector):
    model.zero_grad()
    # first backward pass: keep the graph so the gradient itself can be differentiated again
    grad_params = grad(compute_loss(), param_extract_fn(model), create_graph=True)
    flat_grad = torch.cat([g.view(-1) for g in grad_params])
    # dot the flattened gradient with the given vector, then differentiate that scalar
    grad_vector_product = torch.sum(flat_grad * vector)
    hvp = grad(grad_vector_product, param_extract_fn(model), retain_graph=True)
    return torch.cat([g.contiguous().view(-1) for g in hvp])

hessian_matrix = torch.zeros(num_params, num_params)

for i in range(num_params):
    
    grad_i = grad(hessian_vector_product(torch.eye(num_params, device=device)[i]), param_extract_fn(model), retain_graph=True)
    hessian_matrix[i] = torch.cat([g.contiguous().view(-1) for g in grad_i])

eigenvalues, eigenvectors = torch.linalg.eig(hessian_matrix)  # torch.eig is gone in recent PyTorch; linalg.eig returns complex eigenpairs
eigenvalues = eigenvalues.real.cpu().detach().numpy()  # keep the real parts of the eigenvalues
eigenvectors = eigenvectors.real.t().cpu().detach().numpy()  # transpose so each row is an eigenvector

And here is the full traceback:
RuntimeError                              Traceback (most recent call last)
Cell In[24], line 44
     41 hessian_matrix
     42 for i in range(num_params):
---> 44     grad_i = grad(hessian_vector_product(torch.eye(num_params, device=device)[i]), param_extract_fn(model), retain_graph=True)
     45     hessian_matrix[i] = t.cat([g.contiguous().view(-1) for g in grad_i])
     47 # eigenvalues, eigenvectors = t.eig(hessian_matrix, eigenvectors=True)
     48 # eigenvalues = eigenvalues[:, 0].cpu().detach().numpy()  # Extract real parts of eigenvalues
     49 # eigenvectors = eigenvectors.t().cpu().detach().numpy()  # Transpose eigenvectors

Cell In[24], line 27, in hessian_vector_product(vector)
     25 flat_grad = torch.cat([g.view(-1) for g in grad_params])
     26 grad_vector_product = torch.sum(flat_grad * vector)
---> 27 hvp = grad(grad_vector_product, param_extract_fn(model), retain_graph=True)
     29 return torch.cat([g.contiguous().view(-1) for g in hvp])

File ~/pytorchenv/lib/python3.10/site-packages/torch/autograd/__init__.py:411, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    407     result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
    408         grad_outputs_
    409     )
    410 else:
--> 411     result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    412         t_outputs,
    413         grad_outputs_,
    414         retain_graph,
    415         create_graph,
    416         inputs,
    417         allow_unused,
    418         accumulate_grad=False,
    419     )  # Calls into the C++ engine to run the backward pass
    420 if materialize_grads:
    421     if any(
    422         result[i] is None and not is_tensor_like(inputs[i])
    423         for i in range(len(inputs))
    424     ):

RuntimeError: derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
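
My guess from the message is that the flash-attention SDPA kernel has a backward but no double backward, and the second grad call (needed for the Hessian-vector product) is what trips it. Would forcing the math backend for scaled_dot_product_attention be the right workaround? Something like this is what I'm considering (just a sketch, using the torch.backends.cuda.sdp_kernel context manager; compute_loss is the same helper as above):

import torch

def compute_loss_with_math_sdpa():
    # fall back to the plain math implementation of SDPA for the forward pass,
    # so the recorded graph (hopefully) supports a second backward
    with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                        enable_mem_efficient=False,
                                        enable_math=True):
        return compute_loss()

I realize the math path is slower and uses more memory than flash attention, but the model here is small enough that I only care about getting correct second derivatives. Is this the intended way to handle it, or is there a better approach?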