Input-output Jacobian difference as loss function

I am referring to this forum post here, which is the most relevant answer chain I have found for my question: How to Penalize Norm of End-to-End Jacobian

I am trying to implement a knowledge distillation loss that penalises the difference between the teacher's and the student's Jacobians of the output with respect to the input. However, my model is not training correctly, and I think it is because the computational graph is not what I think it should be. I have tried two things:

  1. Using backward with create_graph = True
import torch                      # imports shared by the snippets below
import torch.nn as nn
import torch.nn.functional as F

def get_jacobian(output, x, batch_size, input_dim, output_dim):
    """To keep grads, create_graph must be set to True, but the computation is very slow."""
    assert x.requires_grad
    jacobian = torch.zeros(batch_size, output_dim, input_dim, device=x.device)
    for i in range(output_dim):
        # One-hot grad_output selects output i, so this backward pass gives row i of the Jacobian
        grad_output = torch.zeros(batch_size, output_dim, device=x.device)
        grad_output[:, i] = 1
        output.backward(grad_output, create_graph=True)
        jacobian[:, i, :] = x.grad.view(batch_size, -1)
        x.grad.zero_()
    print("Jacobian for backward requires grad:", jacobian.requires_grad)
    return jacobian.view(batch_size, -1) # Flatten

If I instead pass only retain_graph = True, the check shows that the result no longer requires grad. At that point I can force it to require grad with

jacobian.requires_grad = True
jacobian.retain_grad()

but I believe this does not reattach the Jacobian to the computational graph (or does it?).
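
To check my understanding, here is a tiny standalone toy (not my real model, and w just stands in for the parameters): flipping requires_grad back on seems to give a fresh leaf with no history, so nothing flows back to w.

import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)
y = (x @ w).sum()

g = torch.autograd.grad(y, x, retain_graph=True)[0]  # no create_graph
print(g.requires_grad)   # False: g is detached from w
g.requires_grad = True   # g becomes a fresh leaf with no history
loss = g.pow(2).sum()
loss.backward()
print(w.grad)            # None: the penalty never reaches w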

  2. Using torch.autograd.grad with retain_graph = True
def get_jacobian(output, x, batch_size, input_dim, output_dim):
    jacobian = torch.zeros(batch_size, output_dim, input_dim, device=x.device)
    for i in range(output_dim):
        grad_output = torch.zeros(batch_size, output_dim, device=x.device)
        grad_output[:, i] = 1
        grad_input = torch.autograd.grad(output, x, grad_output, retain_graph=True)[0]
        jacobian[:, i, :] = grad_input.view(batch_size, input_dim)
    print("Jacobian as calculated requires grad:", jacobian.requires_grad)
    return jacobian.view(batch_size, -1) # Flatten
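
For reference, this is the variant I am currently leaning towards: the same row-by-row loop, but with create_graph=True on the torch.autograd.grad call, and the rows stacked instead of copied into a pre-allocated buffer. The function name is just illustrative and this is only a sketch of what I think the fix looks like, not something I have fully verified.

def get_jacobian_with_graph(output, x, batch_size, input_dim, output_dim):
    """Row-by-row Jacobian; create_graph=True keeps every row differentiable."""
    rows = []
    for i in range(output_dim):
        # One-hot grad_output selects output i, giving row i of the Jacobian
        grad_output = torch.zeros(batch_size, output_dim, device=x.device)
        grad_output[:, i] = 1
        # create_graph=True also retains the graph, so the result has a grad_fn
        grad_input = torch.autograd.grad(output, x, grad_output, create_graph=True)[0]
        rows.append(grad_input.view(batch_size, input_dim))
    jacobian = torch.stack(rows, dim=1)  # [batch_size, output_dim, input_dim]
    print("Jacobian requires grad:", jacobian.requires_grad)  # should be True
    return jacobian.view(batch_size, -1)  # Flatten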

It is also very hard to debug the create_graph = True cases because training becomes so slow. Can someone suggest a method that is guaranteed to keep the gradients flowing through to the eventual loss function? Here is the loss computation code:

def jacobian_loss(scores, targets, inputs, T, alpha, batch_size, loss_fn, input_dim, output_dim):
    """Eq 10, no hard targets used.
    Args:
        scores: logits of student [batch_size, num_classes]
        targets: logits of teacher [batch_size, num_classes]
        inputs: [batch_size, input_dim, input_dim, channels]
        T: float, temperature
        alpha: float, weight of jacobian penalty
        loss_fn: classical distillation loss - MSE, CE, or KLDiv
    """
    # Only compute softmax for KL divergence loss
    if isinstance(loss_fn, nn.KLDivLoss):
        soft_pred = F.log_softmax(scores/T, dim=1)
        soft_targets = F.log_softmax(targets/T, dim=1)  # Have set log_target=True
    elif isinstance(loss_fn, nn.CrossEntropyLoss):
        soft_pred = scores/T
        soft_targets = F.softmax(targets/T, dim=1).argmax(dim=1)
    elif isinstance(loss_fn, nn.MSELoss):
        soft_pred = F.softmax(scores/T, dim=1)
        soft_targets = F.softmax(targets/T, dim=1)
    else:
        raise ValueError("Loss function not supported.")
        
    # t_jac = get_jacobian(targets, inputs, batch_size, input_dim, output_dim)
    # s_jac = get_jacobian(scores, inputs, batch_size, input_dim, output_dim)

    t_jac = get_approx_jacobian(targets, inputs, batch_size, input_dim, output_dim)
    i = torch.argmax(targets, dim=1)
    s_jac = get_approx_jacobian(scores, inputs, batch_size, input_dim, output_dim, i)
    s_jac = torch.div(s_jac, torch.norm(s_jac, 2, dim=-1).unsqueeze(1))
    t_jac = torch.div(t_jac, torch.norm(t_jac, 2, dim=-1).unsqueeze(1))
    jacobian_loss = torch.norm(t_jac-s_jac, 2, dim=1)
    jacobian_loss = torch.mean(jacobian_loss)  # Batchwise reduction
    distill_loss = loss_fn(soft_pred, soft_targets)
    loss = (1-alpha) * distill_loss + alpha * jacobian_loss
    return loss
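
For completeness, this is roughly how I wire the loss into a training step. The models, shapes and hyperparameters (T, alpha, the toy Linear networks) are placeholders, not my actual setup; the point is only how jacobian_loss is called and that the inputs require grad.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)

inputs = torch.randn(8, 3, 32, 32, device=device).requires_grad_(True)  # grads wrt input
teacher_logits = teacher(inputs)   # not under torch.no_grad(), since I need the teacher Jacobian
student_logits = student(inputs)

loss = jacobian_loss(student_logits, teacher_logits, inputs,
                     T=5.0, alpha=0.5, batch_size=8, loss_fn=kl_loss,
                     input_dim=3 * 32 * 32, output_dim=10)
optimizer.zero_grad()
loss.backward()
optimizer.step()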

Update:
I believe using create_graph = True for both autograd and backward works better than retain_graph = True, but I am not sure why.

Update on update: It appears I have resolved this after some fiddling and as long as create_graph = True, the gradients propagate correctly. The same is not true for my other attempts to extract feature maps using hooks and use those in the loss function. I would still like to understand how this all works, though, so if someone could explain to me what is going on it would be great.
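
For anyone who wants to reproduce the behaviour I am describing, here is a minimal standalone toy (not my actual model) showing the difference I see between retain_graph and create_graph:

import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
out = x @ w
grad_output = torch.ones_like(out)

g_retain = torch.autograd.grad(out, x, grad_output, retain_graph=True)[0]
print(g_retain.requires_grad)   # False: no graph is built for the gradient itself

g_create = torch.autograd.grad(out, x, grad_output, create_graph=True)[0]
print(g_create.requires_grad)   # True: the gradient is itself differentiable

penalty = g_create.pow(2).sum()
penalty.backward()
print(w.grad is not None)       # True: the Jacobian penalty reaches the parameters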