Gradient per example before and after clipping

Hello,

I wonder if there is an option in Opacus to access the per-example gradients before and after clipping during training?

Thanks,
Ali

Hi
Accessing per sample gradients before clipping is easy - they’re available between the loss.backward() and optimizer.step() calls. The backward pass calculates per sample gradients and stores them in the parameter.grad_sample attribute. The optimizer step then does the clipping and aggregation, and cleans up the per sample gradients.
For example:

m = Module()
optimizer = optim.SGD(m.parameters(), <...>)
privacy_engine = PrivacyEngine(<...>)
privacy_engine.attach(optimizer)

<...>
output = m(data)
loss = criterion(output, labels)
loss.backward()
print(m.fc.weight.grad_sample) # print per sample gradients
optimizer.step()
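
Note that each grad_sample tensor carries a leading batch dimension on top of the parameter’s own shape (here assuming fc is an nn.Linear layer and a single forward/backward pass, as in the snippet above):

print(m.fc.weight.shape)              # torch.Size([out_features, in_features])
print(m.fc.weight.grad_sample.shape)  # torch.Size([batch_size, out_features, in_features])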

Post-clip values are more tricky - it’s not something we support out of the box.
optimizer.step() does three things at once:

  1. clips per sample gradients
  2. accumulates per sample gradients into parameter.grad
  3. adds noise

This means there’s no easy way to access the intermediate state after clipping, but before accumulation and noising.

I suppose the easiest way to get post-clip values would be to take the pre-clip values and do the clipping yourself, outside of the Opacus code.
All you need to do is replicate a small bit of opacus/opacus/per_sample_gradient_clip.py:clip_and_accumulate():

# step 0 : calculate the layer norms
all_norms = calc_sample_norms(
  named_params=self._named_grad_samples(),
)

# step 1: calculate the clipping factors based on the noise
clipping_factor = self.norm_clipper.calc_clipping_factors(all_norms)

You can then simply multiply your pre-clipping per-sample gradients p.grad_sample by the clipping_factor tensor to get the same clipping that’s happening inside Opacus.
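
Here is a minimal sketch of that logic in plain PyTorch, assuming the model m from the snippet above and a flat clipping threshold C equal to the max_grad_norm you passed to the privacy engine (the 1e-6 epsilon is just an illustrative choice):

import torch

C = 1.0  # max_grad_norm used by the privacy engine (assumption for this sketch)

# per-sample gradients of all trainable parameters, each of shape [batch_size, *param_shape]
grad_samples = [p.grad_sample for p in m.parameters() if p.requires_grad]

# per-sample L2 norm of the full gradient vector (all parameters stacked together)
per_param_norms = [g.reshape(len(g), -1).norm(2, dim=-1) for g in grad_samples]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

# scale down samples whose norm exceeds C, leave the others untouched
clipping_factor = (C / (per_sample_norms + 1e-6)).clamp(max=1.0)

# post-clip per-sample gradients
clipped_grad_samples = [
    g * clipping_factor.view(-1, *([1] * (g.dim() - 1))) for g in grad_samples
]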

Hope this helps

Thanks @ffuuugor!
I wonder, when each layer has both weights and biases, how does Opacus compute the per-example L2 norm of the gradients for clipping? Does it compute it separately for the biases and the weights?

Actually, by default (and as per Abadi et al.) we clip the L2 norm of the entire gradient vector for a given sample, i.e. the vector consisting of the gradients for all trainable parameters in the model stacked together.
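
Concretely, if g_i(p) is the gradient of sample i with respect to parameter p and C is the clipping threshold, the default flat clipping roughly does:

norm_i = sqrt( ||g_i(W1)||^2 + ||g_i(b1)||^2 + ||g_i(W2)||^2 + ||g_i(b2)||^2 + ... )
g_i(p) <- g_i(p) * min(1, C / norm_i)   for every parameter p

so the weights and biases of every layer contribute to the same per-sample norm rather than being clipped separately.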

We’ve tried experimenting with different approaches to clipping, e.g. dynamic per-layer thresholds (see the implementations in clipping.py), but it wasn’t too fruitful.