Setup: a two-head architecture with a common base model. The loss is a function of the gradients of each head's predicted logit w.r.t. the base model's output (the shared representation); these gradients are computed with autograd.grad.
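Concretely, in my own notation (not taken from the code below): let z_b = base(x_b) be the shared feature for sample b in a batch of size B, and let o_i be the sum over the batch of head i's max logit. The loss is the batch-averaged dot product of the per-sample gradients:

\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B} \left\langle \frac{\partial o_1}{\partial z_b},\ \frac{\partial o_2}{\partial z_b} \right\rangle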
Problem: calling backward() accumulates gradients in the heads' parameters but not in the base model's. I suspect this is because autograd.grad treats the shared representation as an independent input rather than as the output of the base model (see the sanity check I added after the output below).
Code: The following snippet reproduces the issue I am facing.
import numpy as np
import torch
import torch.nn as nn

device = 'cpu'
# setup models
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)).to(device)
head1 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)
head2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)
# setup data
xb = torch.randn(4, 10).to(device)
# compute feature output (input to both heads)
feat = base_model(xb)
# compute sum over the batch of each head's max predicted logit
out1 = head1(feat).max(dim=1).values.sum()
out2 = head2(feat).max(dim=1).values.sum()
# gradient of each head's output w.r.t. the shared feature;
# create_graph=True so the gradients themselves are differentiable
g1 = torch.autograd.grad(out1, feat, create_graph=True)[0]
g2 = torch.autograd.grad(out2, feat, create_graph=True)[0]
# loss = average of row-wise dot product of gradients
dp = torch.bmm(g1.unsqueeze(1), g2.unsqueeze(2)).squeeze()
loss = dp.mean()
loss.backward()
# check .grad values of base and head model parameters
def get_avg_grad_norm(model):
    vals = [p.grad.norm().item()
            for p in model.parameters() if p.grad is not None]
    return np.mean(vals)

print('Base:', get_avg_grad_norm(base_model))  # = 0
print('Head #1:', get_avg_grad_norm(head1))
print('Head #2:', get_avg_grad_norm(head2))
Output: average norm of the gradient (.grad) tensors in base_model, head1, and head2, respectively. No gradients accumulate in the base model's parameters:
Base: 0.0
Head #1: 0.029073443884650867
Head #2: 0.020845569670200348
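To rule out the obvious failure mode, here is a quick sanity check I ran (not part of the snippet above): with create_graph=True, g1 and g2 should themselves be attached to the autograd graph rather than detached tensors. Both print True with a populated grad_fn, so a second backward pass through them should be possible:

# sanity check: g1/g2 should be differentiable nodes in the graph
print(g1.requires_grad, g1.grad_fn)  # True, <some Backward object>
print(g2.requires_grad, g2.grad_fn)  # True, <some Backward object>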
I’m not sure how to fix this in order to ensure that the base model's parameters end up with non-zero .grad tensors.