Optimizing a loss based on autograd.grad output

Setup: a two-head architecture with a common base model. The loss is a function of the gradients of each head's predicted logit w.r.t. the base model's output (the shared representation); these gradients are computed with autograd.grad.

Problem: calling backward() on this loss accumulates gradients in the heads' parameters but not in the base model's parameters. I think this is because autograd.grad does not account for the fact that the shared representation is itself the output of the base model.

Code: The following snippet reproduces the issue I am facing.

import numpy as np
import torch
import torch.nn as nn

device = 'cpu'

# setup models
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)).to(device)
head1 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)
head2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)).to(device)

# setup data
xb = torch.randn(4, 10).to(device)

# compute feature output (input to both heads)
feat = base_model(xb)

# for each head: take the max predicted logit per sample and sum over the batch
out1 = head1(feat).max(dim=1).values.sum()
out2 = head2(feat).max(dim=1).values.sum()

# compute gradient of head output w.r.t. shared feature
# (create_graph=True keeps these gradients differentiable so the loss below can be back-propagated)
g1 = torch.autograd.grad(out1, feat, create_graph=True)[0]
g2 = torch.autograd.grad(out2, feat, create_graph=True)[0]

# loss = average of row-wise dot product of gradients 
dp = torch.bmm(g1.unsqueeze(1), g2.unsqueeze(2)).squeeze()
loss = dp.mean()
loss.backward()


# check .grad vals of base and head model parameters
def get_avg_grad_norm(model):
    vals = [p.grad.norm().item()
            for p in model.parameters() if p.grad is not None]
    return np.mean(vals)

print('Base:', get_avg_grad_norm(base_model))  # = 0
print('Head #1:', get_avg_grad_norm(head1))
print('Head #2:', get_avg_grad_norm(head2))

Output: the average norm of the gradient (.grad) tensors of the base_model, head1 and head2 parameters, respectively. No gradients accumulate in the base model's parameters:

Base: 0.0
Head #1: 0.029073443884650867
Head #2: 0.020845569670200348

I’m not sure how to fix this in order to ensure that base model parameters have non-zero .grad tensors.
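
One way to check whether the base parameters are even connected to the loss graph is the following sanity check (not part of the snippet above; it has to be run before loss.backward(), since backward() frees the graph by default):

# Sanity check (not in the original snippet): ask autograd directly for the grads
# of the loss w.r.t. the base parameters, before calling loss.backward().
# allow_unused=True returns None for parameters that are not connected to the loss
# at all; zero tensors mean they are connected but receive exactly zero gradient.
base_grads = torch.autograd.grad(
    loss, list(base_model.parameters()),
    retain_graph=True, allow_unused=True)
print([None if g is None else g.norm().item() for g in base_grads])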

Hi,

I’m not sure what you are trying to achieve.

When you compute the gradient of the head output w.r.t. the shared feature, $$g_1 = \frac{\partial \text{out1}}{\partial \text{feat}}$$ does not depend on the weights of base_model.
Hence, the gradient should be 0.

Hi Victor,

I’m trying to implement a regulariser described in this paper: [2105.05612] Evading the Simplicity Bias: Training a Diverse Set of Models Discovers Solutions with Superior OOD Generalization.

The shared features are a function of the input and the base model's parameters, and the gradient of the head output w.r.t. the shared features depends on the features themselves (through the activations), so it is a function of the base model's parameters as well.
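
To make that concrete, here is the gradient written out for a one-hidden-layer head like the ones in the snippet above (W1, b1, W2 and the selection vector e are my notation, not identifiers from the code; e picks the max logit and sigma is the head's activation):

$$g = \frac{\partial \text{out}}{\partial \text{feat}} = W_1^\top \left( \sigma'(W_1\,\text{feat} + b_1) \odot W_2^\top e \right)$$

The sigma' factor is evaluated at W1 feat + b1, which depends on feat and therefore on the base model's parameters.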

Anyway, I was able to fix the issue. I was using ReLU activations in the MLP heads, and back-propagating this loss to the base model involves differentiating the heads' gradients w.r.t. the shared feature representation, which requires the second derivative of ReLU; that second derivative is zero everywhere. Changing the heads' activations to Softplus (which has a non-zero second derivative) resolved the issue: the gradients now back-propagate to the base model's parameters as expected.
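
For reference, here is a minimal sketch of the fix under the same toy setup as the snippet above; the only change is swapping nn.ReLU() for nn.Softplus() in the two heads (the base model can keep its ReLU, since it is only differentiated once):

import torch
import torch.nn as nn

# same toy setup as above, but with Softplus activations in the heads;
# Softplus has a non-zero second derivative, so the double backward through
# the heads no longer vanishes
torch.manual_seed(0)
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
head1 = nn.Sequential(nn.Linear(10, 10), nn.Softplus(), nn.Linear(10, 2))
head2 = nn.Sequential(nn.Linear(10, 10), nn.Softplus(), nn.Linear(10, 2))

xb = torch.randn(4, 10)
feat = base_model(xb)
out1 = head1(feat).max(dim=1).values.sum()
out2 = head2(feat).max(dim=1).values.sum()
g1 = torch.autograd.grad(out1, feat, create_graph=True)[0]
g2 = torch.autograd.grad(out2, feat, create_graph=True)[0]
loss = torch.bmm(g1.unsqueeze(1), g2.unsqueeze(2)).squeeze().mean()
loss.backward()

# the base model's parameters now receive non-zero gradients
print(sum(p.grad.norm().item() for p in base_model.parameters()))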
