Although functorch is still in beta, I am curious about the implications for Opacus. Do you have any plans to use functorch? Do you think it could make the GradSampler more efficient than the existing hand-crafted gradient hooks, in addition to making it easier to support different types of layers?
Very good question! We were waiting for the functorch release to assess its current state, so we will investigate now.
If you have any insights or thoughts, they'd be very welcome.
As an aside, we are looking for cases where Opacus is much slower than JAX. In the benchmarks we know of, Opacus 1.0 often comes very close to JAX in speed (at most 20% slower), despite not having the power of vmap.
If you know of cases where that's not true, we would be happy to take a look.
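For anyone following along, here is a minimal sketch of what vmap-based per-sample gradients look like, so the comparison with hook-based gradient sampling is concrete. It assumes PyTorch 2.0 or later, where the functorch transforms are available as `torch.func`. The toy `Linear` model, the MSE loss, and the name `sample_loss` are made up for illustration; this is not how Opacus computes gradients internally.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Toy model and batch (illustrative only, not taken from Opacus).
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
y = torch.randn(8, 2)

# Detached copies of the parameters, passed explicitly to the functional call.
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, xi, yi):
    # Loss for ONE example; unsqueeze gives it a batch dimension of size 1.
    pred = functional_call(model, params, (xi.unsqueeze(0),))
    return torch.nn.functional.mse_loss(pred, yi.unsqueeze(0))

# grad differentiates w.r.t. the first argument (params);
# vmap maps over the batch dimension of x and y while sharing params.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)

# Each entry has a leading batch dimension: one gradient per example.
for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))
```

No layer-specific hook is involved here: the per-example loss is written once and vmap handles the batching, which is why this approach tends to generalize across layer types. Whether it is faster than hand-written hooks for a given model would need to be benchmarked.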