JVP and checkpointing

I would like to compute a Jacobian-vector product (JVP) with a large network. However, the transformer model uses checkpointing. If I compute the JVP, I get the error:

```
raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
```

Is this simply not possible, or are there solutions to this problem? I do not want to backpropagate through the output of the JVP, so I do not need checkpointing for that part of the code (e.g. if I could simply disable checkpointing for these lines of code, that would also be okay).

Do you have a stack trace of the error? You should only see that error if you're running a backward pass.
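One thing worth checking: `torch.autograd.functional.jvp` computes the JVP via the double-backward trick, which does call `.grad()` internally and can therefore collide with checkpointing. `torch.func.jvp` instead uses true forward-mode AD and never runs a backward pass. A minimal sketch with a stand-in linear layer in place of your transformer (whether forward-mode traces cleanly through your particular checkpointed blocks is something you would need to verify; otherwise, disabling checkpointing just for the JVP call is a reasonable fallback):

```python
import torch
from torch.func import jvp

torch.manual_seed(0)

# Stand-in for the real model; the actual transformer is assumed.
lin = torch.nn.Linear(4, 2)

x = torch.randn(3, 4)  # primal input
v = torch.randn(3, 4)  # tangent (the "vector" in Jacobian-vector product)

# Forward-mode JVP: no backward pass is run, so the
# "Checkpointing is not compatible with .grad()" check is never hit.
out, tangent = jvp(lin, (x,), (v,))

# For a linear layer f(x) = x @ W.T + b, the Jacobian is W,
# so the JVP should equal v @ W.T (the bias drops out).
expected = v @ lin.weight.T
print(torch.allclose(tangent, expected, atol=1e-6))
```

If you are stuck on `torch.autograd.functional.jvp` for some reason, note that checkpointing with `use_reentrant=False` is generally more permissive about `torch.autograd.grad` than the reentrant variant.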