I was wondering whether PyTorch has an equivalent of the TensorFlow “flatten_up_to” function, which lets you partially flatten a structure while retaining some of its nested structure. Thanks!
You can take a look at the pytree abstraction that PyTorch uses internally: pytorch/_pytree.py at v1.11.0 (pytorch/pytorch on GitHub).
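To make the “flatten up to” idea concrete, here is a minimal pure-Python sketch. It does not use PyTorch's pytree module, and `flatten_up_to` below is a hypothetical helper, not a PyTorch API. It assumes only `list`, `tuple` and `dict` count as containers, sorts dict keys the way `tf.nest` does, and requires matching container types. Each leaf of the shallow structure selects one (possibly still nested) subtree of the input, and those subtrees are returned in order:

```python
from typing import Any, List

# Assumption: only these types are treated as containers; anything else is a leaf.
_NODE_TYPES = (list, tuple, dict)


def _children(node: Any) -> List[Any]:
    """Return the children of a container in a deterministic order."""
    if isinstance(node, dict):
        # Sort keys so the order does not depend on insertion order.
        return [node[k] for k in sorted(node)]
    return list(node)


def flatten_up_to(shallow_tree: Any, input_tree: Any) -> List[Any]:
    """Hypothetical helper: flatten `input_tree` only as deep as `shallow_tree`."""
    if not isinstance(shallow_tree, _NODE_TYPES):
        # Leaf of the shallow structure: keep the whole input subtree as-is.
        return [input_tree]
    if type(shallow_tree) is not type(input_tree):
        raise TypeError(
            f"structure mismatch: {type(shallow_tree).__name__} vs "
            f"{type(input_tree).__name__}"
        )
    if isinstance(shallow_tree, dict) and set(shallow_tree) != set(input_tree):
        raise ValueError("dict keys differ between shallow and input trees")
    shallow_children = _children(shallow_tree)
    input_children = _children(input_tree)
    if len(shallow_children) != len(input_children):
        raise ValueError("containers have different lengths")
    result: List[Any] = []
    for s, i in zip(shallow_children, input_children):
        result.extend(flatten_up_to(s, i))
    return result


if __name__ == "__main__":
    shallow = [None, (None, None)]
    nested = [[1, 2], ([3, 4], {"a": 5})]
    print(flatten_up_to(shallow, nested))  # [[1, 2], [3, 4], {'a': 5}]
```

The inner `[1, 2]`, `[3, 4]` and `{"a": 5}` survive intact because the shallow structure stops one level above them. A pytree-based version would follow the same recursion, just driven by the registered node types instead of a hard-coded tuple.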
Here is a nice overview of pytrees, if you haven’t seen it before: Pytrees — JAX documentation