How to partially flatten a structure, retaining some of the nested structure?

I was wondering if PyTorch has an equivalent of the TensorFlow “flatten_up_to” function. It allows one to partially flatten a structure while retaining some of the nested structure. Thanks!
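For reference, here is a rough plain-Python sketch of the behavior being asked about. `flatten_up_to` below is a hypothetical helper written for illustration, not TensorFlow's implementation. It assumes that only lists, tuples and dicts count as containers, and that dict entries are visited in sorted-key order. Each leaf of `shallow_tree` picks out the whole subtree of `input_tree` at the same position, so any nesting below that depth is kept intact.

```python
from typing import Any, List, Optional


def _children(node: Any) -> Optional[List[Any]]:
    """Return the child subtrees of a container node, or None for a leaf.

    Assumption: only lists, tuples and dicts are containers; dict
    children are ordered by sorted key.
    """
    if isinstance(node, (list, tuple)):
        return list(node)
    if isinstance(node, dict):
        return [node[k] for k in sorted(node)]
    return None


def flatten_up_to(shallow_tree: Any, input_tree: Any) -> List[Any]:
    """Flatten input_tree only as deep as shallow_tree is nested (hypothetical helper)."""
    shallow_children = _children(shallow_tree)
    if shallow_children is None:
        # Leaf of the shallow tree: keep the matching input subtree whole.
        return [input_tree]
    input_children = _children(input_tree)
    if input_children is None or type(shallow_tree) is not type(input_tree):
        raise TypeError("input_tree does not match the structure of shallow_tree")
    if isinstance(shallow_tree, dict) and sorted(shallow_tree) != sorted(input_tree):
        raise ValueError("dict keys of input_tree and shallow_tree differ")
    if len(shallow_children) != len(input_children):
        raise ValueError("sequence lengths of input_tree and shallow_tree differ")
    result: List[Any] = []
    for s_child, i_child in zip(shallow_children, input_children):
        result.extend(flatten_up_to(s_child, i_child))
    return result


# Usage: the shallow tree is two levels deep, so the third level survives.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
print(flatten_up_to(shallow_tree, input_tree))  # [[2, 2], [3, 3], [4, 9], [5, 5]]
```

The values stored in `shallow_tree` are ignored; only its shape matters, which is why booleans work as placeholders here.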


You can take a look at the pytree abstraction that is used internally by PyTorch: pytorch/_pytree.py at v1.11.0 · pytorch/pytorch · GitHub
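To make the pytree idea concrete, here is a simplified pure-Python stand-in for it: flatten a nested container into a list of leaves plus a "spec" describing its shape, then rebuild the shape from leaves and spec. This is not the code in `_pytree.py`. The names `simple_tree_flatten`, `simple_tree_unflatten` and the `is_leaf` predicate are hypothetical, and only plain lists, tuples and dicts are treated as containers. The `is_leaf` hook is one way to retain nested structure: any node it accepts is kept whole as a single leaf (JAX's `jax.tree_util.tree_flatten` takes an `is_leaf` argument for a similar purpose).

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical spec format: the string "*" marks a leaf; a container is
# described as (container_type, keys_or_length, child_specs).
LEAF = "*"
Spec = Any


def simple_tree_flatten(
    tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[List[Any], Spec]:
    """Split tree into its leaves and a spec describing its structure."""
    if is_leaf is not None and is_leaf(tree):
        return [tree], LEAF  # keep this subtree whole
    if type(tree) in (list, tuple):
        keys: Any = len(tree)
        children = list(tree)
    elif type(tree) is dict:
        keys = sorted(tree)  # assumption: deterministic sorted-key order
        children = [tree[k] for k in keys]
    else:
        return [tree], LEAF
    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = simple_tree_flatten(child, is_leaf)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, (type(tree), keys, child_specs)


def simple_tree_unflatten(leaves: List[Any], spec: Spec) -> Any:
    """Rebuild a tree from leaves in the order simple_tree_flatten produced them."""
    it = iter(leaves)

    def build(s: Spec) -> Any:
        if isinstance(s, str):  # a leaf marker
            return next(it)
        kind, keys, child_specs = s
        children = [build(cs) for cs in child_specs]
        if kind is dict:
            return dict(zip(keys, children))
        return kind(children)  # list(...) or tuple(...)

    return build(spec)


# Usage: full flattening versus keeping tuples intact.
tree = {"x": [1, 2], "y": (3, {"z": 4})}
leaves, spec = simple_tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]
print(simple_tree_unflatten([10, 20, 30, 40], spec))  # {'x': [10, 20], 'y': (30, {'z': 40})}
partial, _ = simple_tree_flatten(tree, is_leaf=lambda n: isinstance(n, tuple))
print(partial)  # [1, 2, (3, {'z': 4})]
```

Keeping the structure in a separate spec is what lets the same leaves be transformed and then put back into the original nesting.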

Here is a nice overview of pytrees, if you haven’t seen it before: Pytrees — JAX documentation