Functorch's per-sample gradients and Opacus

Hi! PyTorch recently released the first version of functorch. This is exciting because functorch makes it easy to compute per-sample gradients, much like JAX. There is even a dedicated tutorial on per-sample gradients with functorch.
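For reference, the recipe from that tutorial looks roughly like this (a minimal sketch following the functorch beta API; the toy model, shapes, and loss are just for illustration and the function names may shift in later releases):

```python
import torch
import torch.nn.functional as F
from functorch import make_functional, vmap, grad

# toy model and batch; any consistent shapes work
model = torch.nn.Linear(16, 2)
fmodel, params = make_functional(model)  # functional version of the model + its parameters

data = torch.randn(8, 16)                # batch of 8 samples
targets = torch.randint(0, 2, (8,))

def loss_fn(params, sample, target):
    # treat a single sample as a batch of size 1
    logits = fmodel(params, sample.unsqueeze(0))
    return F.cross_entropy(logits, target.unsqueeze(0))

# grad differentiates w.r.t. params (the first argument);
# vmap maps that computation over the batch dimension of data/targets
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, data, targets)

# one tensor per parameter, each of shape (batch_size, *param.shape)
for g in per_sample_grads:
    print(g.shape)
```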

Although functorch is still in beta, I am curious about the implications for Opacus. Do you have any plans to use functorch? Do you think it could improve the efficiency of the GradSampleModule compared to the existing hand-crafted per-layer gradient hooks? (in addition to making it easier to support different types of layers, of course)
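For context, my understanding of the current approach is that Opacus wraps the model in GradSampleModule, whose hooks populate a `grad_sample` attribute on each parameter after `backward()`. A sketch, assuming the Opacus 1.0 API (toy model and shapes made up for illustration):

```python
import torch
import torch.nn.functional as F
from opacus.grad_sample import GradSampleModule

# GradSampleModule attaches forward/backward hooks with hand-written
# per-layer formulas for per-sample gradients
model = torch.nn.Linear(16, 2)
gs_model = GradSampleModule(model)

data = torch.randn(8, 16)
targets = torch.randint(0, 2, (8,))

loss = F.cross_entropy(gs_model(data), targets)
loss.backward()

# the hooks store one gradient per sample on each parameter
for p in gs_model.parameters():
    print(p.grad_sample.shape)  # (batch_size, *p.shape)
```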

Thanks!

Very good question! We were waiting for the functorch release to assess its current state, so we will investigate now.
If you have any insights or thoughts, they'd be very welcome.

Thanks, that’s good to hear!

As an aside, we are looking for cases where Opacus is much slower than JAX. In the benchmarks we know of, Opacus 1.0 often comes very close to JAX in terms of speed (at most 20% slower), despite not enjoying the power of vmap.

If you know of cases where that’s not true, we would be happy to take a look :slight_smile:.
