Problem: calling backward() accumulates gradients in the heads’ parameters but not in the base model. I think this is because autograd.grad does not account for the fact that the shared representation is the output of the base model.

Code: The following snippet reproduces the issue I am facing.

import numpy as np
import torch
import torch.nn as nn

device = 'cpu'

# setup models
base_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10,10)).to(device)
head1 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10,2)).to(device)
head2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10,2)).to(device)

# setup data
xb = torch.randn(4, 10).to(device)

# compute feature output (input to both heads)
feat = base_model(xb)

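# compute sum of heads' predicted logits
# (this step is missing from the post; the lines below are a reconstruction,
# and the names out1/out2 are assumptions)
out1 = head1(feat).sum()
out2 = head2(feat).sum()

# gradient of each summed output w.r.t. the shared features;
# create_graph=True so that the loss below can be back-propagated
g1 = torch.autograd.grad(out1, feat, create_graph=True)[0]
g2 = torch.autograd.grad(out2, feat, create_graph=True)[0]
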
# loss = average of row-wise dot product of gradients
dp = torch.bmm(g1.unsqueeze(1), g2.unsqueeze(2)).squeeze()
loss = dp.mean()
loss.backward()

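def get_avg_grad_norm(model):
    # average L2 norm of .grad over the parameters that received a gradient
    vals = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    return np.mean(vals)

print('Base:', get_avg_grad_norm(base_model))  # = 0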

Base: 0.0

I’m not sure how to fix this in order to ensure that base model parameters have non-zero .grad tensors.

Hi,

I’m not sure what you are trying to achieve.

When you compute the gradient of head-output w.r.t. shared feature, $$\frac{\partial g1}{\partial feat}$$ does not depend on the weights of base_model.
Hence, the gradient should be 0.

Hi Victor,

I’m trying to implement a regulariser described in this paper: [2105.05612] Evading the Simplicity Bias: Training a Diverse Set of Models Discovers Solutions with Superior OOD Generalization.

The shared features are a function of the input and model parameters, so the gradient of head output w.r.t. shared features is a function of the model parameters as well.
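To make that dependence explicit, here is a short derivation for an assumed two-layer head $h(f) = W_2\,\sigma(W_1 f + b_1) + b_2$ acting on the shared feature $f$. The symbols $W_1, W_2, b_1, b_2, \sigma, z, \theta$ are notation introduced for this sketch and are not taken from the paper; $g$ plays the role of g1 or g2 in the snippet.

$$
\begin{aligned}
z &= W_1 f + b_1, \\
g(f) &= \nabla_f \big(\mathbf{1}^\top h(f)\big) = W_1^\top \big(\sigma'(z) \odot W_2^\top \mathbf{1}\big), \\
\frac{\partial g}{\partial f} &= W_1^\top \,\mathrm{diag}\big(\sigma''(z) \odot W_2^\top \mathbf{1}\big)\, W_1 .
\end{aligned}
$$

The base model parameters $\theta$ enter only through $f$, so by the chain rule the loss gradient w.r.t. $\theta$ is a sum over the heads of $\frac{\partial \text{loss}}{\partial g}\,\frac{\partial g}{\partial f}\,\frac{\partial f}{\partial \theta}$. Under these assumptions, $g$ depends on $f$ only through $\sigma'(z)$, and whether anything reaches the base model depends on $\sigma''$.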

Anyway, I was able to fix the issue. I was using ReLU activations in the MLP heads, and the derivative of the loss w.r.t. the shared feature representation involves the second derivative of the ReLU function, which is zero everywhere except at the origin, where it is undefined. Changing the activations to Softplus (which has a non-zero second derivative) resolved the issue: the gradients now back-propagate to the base model parameters as expected.
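For anyone who wants to check this, here is a minimal sketch that runs the same setup twice, once with ReLU heads and once with Softplus heads. The helper name `base_grad_norm`, the fixed seed, and the use of `(g1 * g2).sum(dim=1)` in place of `torch.bmm` are my own choices for illustration, not part of the original snippet or the paper.

```python
import torch
import torch.nn as nn


def base_grad_norm(activation, seed=0):
    """Sum of gradient norms that reach the base model for a given head activation.

    Hypothetical helper: mirrors the snippet in this thread, with the head
    activation passed in as a module class (e.g. nn.ReLU or nn.Softplus).
    """
    torch.manual_seed(seed)
    base = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
    head1 = nn.Sequential(nn.Linear(10, 10), activation(), nn.Linear(10, 2))
    head2 = nn.Sequential(nn.Linear(10, 10), activation(), nn.Linear(10, 2))

    xb = torch.randn(4, 10)
    feat = base(xb)  # shared features, input to both heads

    # gradients of each head's summed logits w.r.t. the shared features;
    # create_graph=True keeps the graph so the loss can be differentiated again
    g1 = torch.autograd.grad(head1(feat).sum(), feat, create_graph=True)[0]
    g2 = torch.autograd.grad(head2(feat).sum(), feat, create_graph=True)[0]

    # average of the row-wise dot products of the two gradients
    loss = (g1 * g2).sum(dim=1).mean()
    loss.backward()

    # parameters whose .grad is None contribute nothing to the sum
    return sum(p.grad.norm().item() for p in base.parameters() if p.grad is not None)


relu_norm = base_grad_norm(nn.ReLU)
softplus_norm = base_grad_norm(nn.Softplus)
print('ReLU heads:    ', relu_norm)
print('Softplus heads:', softplus_norm)
```

The ReLU run should report zero for the base model, while the Softplus run should report a positive value; the exact number depends on the seed.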
