Getting the Jacobian of transformer encoder outputs with respect to the inputs


Suppose I pass an input to a transformer encoder with batch_first=True, and I want the Jacobian of one row of the output with respect to one row of the input. The input tensor I has shape (B, M, D), where B is the batch size, and the output tensor is O. How can I get the Jacobians of O[:, i] with respect to I[:, j]? The result should be B Jacobians, each of size D x D. Indexing seems to break the computation graph, so I get a None grad, and using jacfwd to get a 6D tensor seems like it would be too slow. How can this be done efficiently?
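Written out, the quantity being asked for is the following. The second line holds only under the assumption that batch elements never interact (as in a TransformerEncoder in eval mode, where attention and LayerNorm act within each example). Under that assumption the full Jacobian of O[:, i] with respect to I[:, j], of shape (B, D, B, D), is block-diagonal, and the B diagonal blocks are all that is needed.

```latex
J^{(b)} = \frac{\partial O_{b,i,:}}{\partial I_{b,j,:}} \in \mathbb{R}^{D \times D},
\qquad J^{(b)}_{kl} = \frac{\partial O_{b,i,k}}{\partial I_{b,j,l}},
\qquad b = 1, \dots, B

\frac{\partial O_{b,i,k}}{\partial I_{b',j,l}} = 0 \quad \text{for } b \neq b'
```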

Indexing shouldn't break the computation graph. Does something like the following work?

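import torch
from torch.autograd.functional import jacobian

x = <original input>  # shape (B, M, D); `encoder` is your TransformerEncoder

def fn(sliced_input):
    # Clone the full input and copy the slice into row j in place,
    # so the encoder still sees the whole sequence.
    inner_input = x.detach().clone()
    inner_input[:, j] = sliced_input
    output = encoder(inner_input)  # call the encoder, not fn itself
    return output[:, i]

jac = jacobian(fn, x[:, j])  # shape (B, D, B, D)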
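Jacobian-ing the wrapper over the whole (B, D) slice returns a (B, D, B, D) tensor. A minimal sketch of pulling out the B per-example D x D blocks from its diagonal, assuming a small randomly initialised encoder in eval mode (the sizes, indices and layer settings are hypothetical):

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
B, M, D = 3, 5, 8  # hypothetical batch size, sequence length, model width
i, j = 1, 2        # output row i, input row j

layer = nn.TransformerEncoderLayer(d_model=D, nhead=2, dim_feedforward=16, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout so the function is deterministic

x = torch.randn(B, M, D)

def fn(sliced_input):
    # Write the (B, D) slice into row j of a copy of the full input.
    inner_input = x.detach().clone()
    inner_input[:, j] = sliced_input
    return encoder(inner_input)[:, i]  # (B, D)

full = jacobian(fn, x[:, j])  # (B, D_out, B, D_in)

# full[b, :, b, :] is the D x D Jacobian for batch element b.
# diagonal() over dims 0 and 2 gives (D_out, D_in, B); move B to the front.
per_example = full.diagonal(dim1=0, dim2=2).permute(2, 0, 1)  # (B, D, D)
```

Off-diagonal blocks (b != b') are zero here, which is why this approach wastes work as B grows.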

You're right, the computation graph doesn't break. I've found that the actual issue is this: although the input tensor has requires_grad set to True, calling torch.autograd.grad on a slice of the input with allow_unused=True returns None, since using the slice passes a tensor with requires_grad set to False. I can't create a copy of the slice and pass only that to the encoder, because the encoder needs the entire input to produce a meaningful output. Is there a workaround for this issue?
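A minimal sketch reproducing the None described above, assuming a small randomly initialised encoder (sizes and indices are hypothetical). Here the None arises because x[:, j] is a new tensor created after the forward pass, so the output was never computed from it. Differentiating with respect to the full input and slicing the gradient afterwards does work, one output coordinate at a time; summing over the batch in that loop assumes batch elements do not interact.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, M, D = 2, 4, 6  # hypothetical sizes
i, j = 0, 3

layer = nn.TransformerEncoderLayer(d_model=D, nhead=2, dim_feedforward=12, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1).eval()

x = torch.randn(B, M, D, requires_grad=True)
out = encoder(x)

# The slice is taken after the forward pass, so `out` does not depend on it.
sliced = x[:, j]
(g,) = torch.autograd.grad(out[:, i].sum(), sliced, allow_unused=True, retain_graph=True)
print(g)  # None

# Workaround: differentiate w.r.t. the full input, then slice the gradient.
# Summing over the batch is safe only because batch elements are independent.
rows = []
for k in range(D):
    (g_full,) = torch.autograd.grad(out[:, i, k].sum(), x, retain_graph=True)
    rows.append(g_full[:, j])      # (B, D): d out[b, i, k] / d x[b, j, :]
jac = torch.stack(rows, dim=1)     # (B, D_out, D_in)
```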

That is a good question. There are two parts to this:

  1. You'll want a wrapper function that accepts a slice and returns a slice, so that when you pass this wrapper to the Jacobian function, it only iterates over that subset of inputs/outputs.
  2. Additionally, you'd keep the entire input tensor as a global and, inside the wrapper, copy the slice in place into (a clone of) that global input before passing the full tensor to the encoder (see my code example above for details).
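The two steps above can be combined with one extra trick for efficiency. This is a sketch under the assumption that batch elements never interact (true for a TransformerEncoder in eval mode): if the wrapper sums its output slice over the batch, the Jacobian has shape (D, B, D) and contains all B blocks, at the cost of D backward passes instead of B·D. The helper name row_jacobians and the model sizes are hypothetical.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

def row_jacobians(encoder, x, i, j):
    """Return a (B, D, D) tensor whose b-th entry is d encoder(x)[b, i] / d x[b, j].

    Assumes batch elements do not interact (e.g. a TransformerEncoder in
    eval mode), so summing the output over the batch loses no information.
    """
    base = x.detach()

    def wrapper(sliced_input):             # accepts a (B, D) slice ...
        full = base.clone()
        full[:, j] = sliced_input          # in-place copy into the full input
        return encoder(full)[:, i].sum(0)  # ... returns a (D,) batch-summed slice

    jac = jacobian(wrapper, base[:, j])    # (D_out, B, D_in)
    return jac.permute(1, 0, 2)            # (B, D_out, D_in)

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()
x = torch.randn(4, 6, 8)
J = row_jacobians(encoder, x, i=1, j=3)
print(J.shape)  # torch.Size([4, 8, 8])
```

torch.autograd.functional.jacobian also accepts vectorize=True, which may speed up the D backward passes further; whether every op in the encoder supports it is worth checking on your PyTorch version. If the model does mix information across the batch (e.g. BatchNorm in training mode), the summing trick is invalid and the block-diagonal extraction is needed instead.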