Setup: a two-head architecture with a common base model. The loss is a function of the gradients of each head's predicted logit w.r.t. the base model's output (the shared representation); these gradients are computed with torch.autograd.grad(..., create_graph=True).
Problem: loss.backward() accumulates gradients in the heads' parameters but not in the base model's. I think this is because autograd.grad does not account for the fact that the shared representation is itself the output of the base model.
Code: The following snippet reproduces the issue I am facing.
import numpy as np
import torch
import torch.nn as nn

device = 'cpu'

# setup models
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)).to(device)
head1 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)
head2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)

# setup data
xb = torch.randn(4, 10).to(device)

# compute feature output (input to both heads)
feat = base_model(xb)

# compute sum of heads' predicted logits
out1 = head1(feat).max(dim=1).values.sum()
out2 = head2(feat).max(dim=1).values.sum()

# compute gradient of each head's output w.r.t. the shared feature
# (autograd.grad returns a tuple, so unpack the single element)
g1, = torch.autograd.grad(out1, feat, create_graph=True)
g2, = torch.autograd.grad(out2, feat, create_graph=True)

# loss = average of row-wise dot products of the two gradients
dp = torch.bmm(g1.unsqueeze(1), g2.unsqueeze(2)).squeeze()
loss = dp.mean()
loss.backward()

# check .grad values of base and head model parameters
def get_avg_grad_norm(model):
    vals = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    return np.mean(vals)

print('Base:', get_avg_grad_norm(base_model))  # = 0
print('Head #1:', get_avg_grad_norm(head1))
print('Head #2:', get_avg_grad_norm(head2))
Output: the average norms of the gradient (.grad) tensors in head1 and head2 are non-zero, but no gradients are accumulated in the base model's parameters:
Base: 0.0
Head #1: 0.029073443884650867
Head #2: 0.020845569670200348
I'm not sure how to fix this in order to ensure that the base model's parameters have non-zero .grad values after loss.backward().
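One check that might help narrow this down (a sketch, not something from the snippet above): ask autograd directly whether the loss graph reaches the base model's parameters at all, using allow_unused=True. A parameter reported as None is not connected to the graph of loss; a zero tensor means it is connected but receives a zero gradient. This has to run before (or instead of) loss.backward(), since backward() frees the graph by default.

# Sketch: probe graph connectivity from loss to the base model's parameters.
# Assumes base_model and loss are the objects defined in the snippet above.
base_params = list(base_model.parameters())
probe = torch.autograd.grad(loss, base_params, retain_graph=True, allow_unused=True)
for (name, _), g in zip(base_model.named_parameters(), probe):
    if g is None:
        print(name, ': not reachable from the loss graph')
    else:
        print(name, ': reachable, grad norm =', g.norm().item())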