I am referring to the forum post "How to Penalize Norm of End-to-End Jacobian", which is the most relevant answer chain I have found for my question.
I am trying to implement a knowledge distillation loss that penalises the difference between the teacher's and the student's Jacobians of the output with respect to the input. However, my model is not training correctly, and I think it is because the computational graph is not what I think it should be. I have tried two things:
- Using backward with create_graph = True
def get_jacobian(output, x, batch_size, input_dim, output_dim):
    """In order to keep grads, need to set create_graph to True, but computation very slow."""
    assert x.requires_grad
    jacobian = torch.zeros(batch_size, output_dim, input_dim, device=x.device)
    for i in range(output_dim):
        grad_output = torch.zeros(batch_size, output_dim, device=x.device)
        grad_output[:, i] = 1
        output.backward(grad_output, create_graph=True)
        jacobian[:, i, :] = x.grad.view(batch_size, -1)
        x.grad.zero_()
    print("Jacobian for backward requires grad:", jacobian.requires_grad)
    return jacobian.view(batch_size, -1)  # Flatten
If I instead pass only retain_graph = True, the same check shows that the result no longer requires grad. At that point, I can force it to require grad with
jacobian.requires_grad = True
jacobian.retain_grad()
but I believe this does not attach the Jacobian back onto the computational graph (or does it?)
- Using torch.autograd.grad with retain_graph = True
def get_jacobian(output, x, batch_size, input_dim, output_dim):
    jacobian = torch.zeros(batch_size, output_dim, input_dim, device=x.device)
    for i in range(output_dim):
        grad_output = torch.zeros(batch_size, output_dim, device=x.device)
        grad_output[:, i] = 1
        grad_input = torch.autograd.grad(output, x, grad_output, retain_graph=True)[0]
        jacobian[:, i, :] = grad_input.view(batch_size, input_dim)
    print("Jacobian as calculated requires grad:", jacobian.requires_grad)
    return jacobian.view(batch_size, -1)  # Flatten
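To make the difference between the two attempts visible without a long training run, the check can be done on a toy model. The following is a minimal sketch, not the code from the post: `batch_jacobian` and the small `nn.Sequential` model are hypothetical stand-ins, and the sketch assumes each sample's output depends only on that sample's input. It compares a Jacobian built without `create_graph`, the same tensor with `requires_grad` forced on, and a Jacobian built with `create_graph=True`.

```python
import torch
import torch.nn as nn


def batch_jacobian(output, x, create_graph=True):
    """Return d(output)/d(x) with shape [batch, output_dim, input_dim].

    Hypothetical helper. Assumes sample i of `output` depends only on
    sample i of `x` (e.g. no batch norm in training mode).
    """
    batch_size, output_dim = output.shape
    rows = []
    for i in range(output_dim):
        grad_output = torch.zeros_like(output)
        grad_output[:, i] = 1.0
        # autograd.grad returns the gradient instead of accumulating into
        # .grad, so the parameters' .grad buffers are left untouched.
        (grad_input,) = torch.autograd.grad(
            output, x, grad_output, retain_graph=True, create_graph=create_graph
        )
        rows.append(grad_input.reshape(batch_size, -1))
    return torch.stack(rows, dim=1)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))  # toy stand-in
x = torch.randn(5, 4, requires_grad=True)
out = model(x)

jac_detached = batch_jacobian(out, x, create_graph=False)
jac_attached = batch_jacobian(out, x, create_graph=True)
print(jac_detached.requires_grad, jac_detached.grad_fn)  # False None
print(jac_attached.requires_grad)                         # True

# Forcing requires_grad on the detached result only creates a new leaf tensor.
forced = jac_detached.clone().requires_grad_(True)
forced.pow(2).sum().backward()
first_weight = model[0].weight
grad_after_forced = first_weight.grad      # None: nothing reached the model

# A penalty on the attached Jacobian does backpropagate into the model.
jac_attached.pow(2).sum().backward()
grad_after_attached = first_weight.grad    # a tensor with the weight's shape
```

In this sketch the forced tensor is a leaf with no history, so a loss built on it cannot update the model, whereas the `create_graph=True` version carries a `grad_fn` that leads back to the parameters.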
It’s also very hard to debug the cases where create_graph = True since the training takes so long. Can someone help me with a method that is guaranteed to retain the grads to the eventual loss function? Here is the loss computation code:
import torch
import torch.nn as nn
import torch.nn.functional as F

def jacobian_loss(scores, targets, inputs, T, alpha, batch_size, loss_fn, input_dim, output_dim):
    """Eq 10, no hard targets used.
    Args:
        scores: logits of student [batch_size, num_classes]
        targets: logits of teacher [batch_size, num_classes]
        inputs: [batch_size, input_dim, input_dim, channels]
        T: float, temperature
        alpha: float, weight of jacobian penalty
        loss_fn: for classical distill loss - MSE, BCE, KLDiv
    """
    # Only compute softmax for KL divergence loss
    if isinstance(loss_fn, nn.KLDivLoss):
        soft_pred = F.log_softmax(scores/T, dim=1)
        soft_targets = F.log_softmax(targets/T, dim=1)  # Have set log_target=True
    elif isinstance(loss_fn, nn.CrossEntropyLoss):
        soft_pred = scores/T
        # Pass dim explicitly; calling softmax without dim is deprecated
        soft_targets = F.softmax(targets/T, dim=1).argmax(dim=1)
    elif isinstance(loss_fn, nn.MSELoss):
        soft_pred = F.softmax(scores/T, dim=1)
        soft_targets = F.softmax(targets/T, dim=1)
    else:
        raise ValueError("Loss function not supported.")
    # t_jac = get_jacobian(targets, inputs, batch_size, input_dim, output_dim)
    # s_jac = get_jacobian(scores, inputs, batch_size, input_dim, output_dim)
    t_jac = get_approx_jacobian(targets, inputs, batch_size, input_dim, output_dim)
    i = torch.argmax(targets, dim=1)
    s_jac = get_approx_jacobian(scores, inputs, batch_size, input_dim, output_dim, i)
    # Normalise each sample's flattened Jacobian to unit L2 norm
    s_jac = torch.div(s_jac, torch.norm(s_jac, 2, dim=-1).unsqueeze(1))
    t_jac = torch.div(t_jac, torch.norm(t_jac, 2, dim=-1).unsqueeze(1))
    jacobian_loss = torch.norm(t_jac-s_jac, 2, dim=1)
    jacobian_loss = torch.mean(jacobian_loss)  # Batchwise reduction
    distill_loss = loss_fn(soft_pred, soft_targets)
    loss = (1-alpha) * distill_loss + alpha * jacobian_loss
    return loss
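The Jacobian-matching part of this loss can be isolated and checked on dummy tensors. The following is a minimal sketch under stated assumptions, not the code used in the post: `jacobian_matching_term` is a hypothetical helper, `F.normalize` stands in for the manual division so that a zero norm does not produce a division by zero, and detaching the teacher Jacobian is an assumption about the intended setup.

```python
import torch
import torch.nn.functional as F


def jacobian_matching_term(s_jac, t_jac, eps=1e-12):
    """Mean L2 distance between unit-normalised, flattened Jacobians.

    Hypothetical helper. s_jac and t_jac have shape [batch_size, D].
    """
    s_unit = F.normalize(s_jac, p=2, dim=1, eps=eps)
    # Assumption: the teacher is fixed, so no gradient is needed through it.
    t_unit = F.normalize(t_jac.detach(), p=2, dim=1, eps=eps)
    return torch.norm(t_unit - s_unit, p=2, dim=1).mean()


torch.manual_seed(0)
t_jac = torch.randn(4, 6)
s_jac = torch.randn(4, 6, requires_grad=True)

same = jacobian_matching_term(t_jac.clone(), t_jac)   # identical directions
opposite = jacobian_matching_term(-t_jac, t_jac)      # opposite directions
term = jacobian_matching_term(s_jac, t_jac)
term.backward()                                        # fills s_jac.grad
```

Because both Jacobians are scaled to unit norm, the term only compares directions: it lies between 0 (same direction) and 2 (opposite direction) and ignores the magnitude of either Jacobian.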
Update:
I believe using create_graph = True for both torch.autograd.grad and backward works better than retain_graph = True, but I am not sure why.
Update on update: It appears I have resolved this after some fiddling: as long as create_graph = True, the gradients propagate correctly. The same is not true for my other attempts to extract feature maps using hooks and use them in the loss function. I would still like to understand how this all works, though, so if someone could explain what is going on, that would be great.
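One way to see why the two flags behave differently is a scalar example that can be checked by hand. This is a minimal sketch with made-up values: `w` stands in for a model parameter and `x` for the input. `retain_graph=True` only keeps the forward graph alive for another backward pass, while `create_graph=True` also records the backward computation itself, so the gradient becomes a differentiable function of `w`.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)  # stands in for a model parameter
x = torch.tensor(2.0, requires_grad=True)  # stands in for the input

y = w * x ** 2                              # dy/dx = 2*w*x = 12

# retain_graph=True only: the gradient comes back as a plain constant.
(g_plain,) = torch.autograd.grad(y, x, retain_graph=True)

# create_graph=True: the gradient carries its own graph back to w.
(g_graph,) = torch.autograd.grad(y, x, create_graph=True)

penalty = g_graph ** 2                      # (2*w*x)^2 = 144
# d(penalty)/dw = 2 * (2*w*x) * (2*x) = 8*w*x^2 = 96
(dpenalty_dw,) = torch.autograd.grad(penalty, w)

print(g_plain.requires_grad, g_graph.requires_grad)  # False True
```

Both gradients have the same value, but only the one computed with `create_graph=True` can take part in a loss whose gradient reaches `w`.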