I was wondering if PyTorch has an equivalent of TensorFlow's “flatten_up_to” function. It lets you partially flatten a structure while retaining some of the nesting. Thanks!
Line 1158:
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  return map_structure_with_tuple_paths_up_to(structure[0],
                                              func,
                                              *structure,
                                              **kwargs)


def _yield_flat_up_to(shallow_tree, input_tree, is_nested_fn, path=()):
  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.

  Args:
    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
    input_tree: Nested structure. Return the paths and values from this tree.
      Must have the same upper structure as shallow_tree.
    is_nested_fn: Function used to test if a value should be treated as a
      nested structure.
    path: Tuple. Optional argument, only used when recursing. The path from the
      root of the original shallow_tree, down to the root of the shallow_tree
zouzias (Anastasios Zouzias), May 11, 2022, 7:33am, #2
You can take a look at the pytree abstraction that is used internally by PyTorch: pytorch/_pytree.py at v1.11.0 · pytorch/pytorch · GitHub
Here is a nice overview of pytrees, if you haven’t seen it before: Pytrees — JAX documentation
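In case it helps, here is a minimal pure-Python sketch of the "flatten up to" idea itself. It is not PyTorch's pytree API and not TensorFlow's implementation: the names `flatten_up_to` and `_is_nested` are hypothetical helpers written for this post, and the sketch assumes that only lists, tuples, and dicts count as nested containers, with dict keys visited in sorted order. Anything else (a tensor, a number, a string) is treated as a leaf.

```python
from collections.abc import Mapping


def _is_nested(x):
    # Assumption of this sketch: only lists, tuples and dicts are containers.
    # Everything else (e.g. a torch.Tensor) is treated as a leaf.
    return isinstance(x, (list, tuple, Mapping))


def flatten_up_to(shallow_tree, input_tree):
    """Flatten input_tree, descending no deeper than the leaves of shallow_tree."""
    if not _is_nested(shallow_tree):
        # A leaf of shallow_tree: keep the matching input subtree intact.
        return [input_tree]
    if not _is_nested(input_tree):
        raise TypeError("input_tree is a leaf where shallow_tree is nested")

    if isinstance(shallow_tree, Mapping):
        if not isinstance(input_tree, Mapping):
            raise TypeError("shallow_tree is a dict but input_tree is not")
        flat = []
        for key in sorted(shallow_tree):  # sorted order is a choice of this sketch
            if key not in input_tree:
                raise ValueError(f"key {key!r} is missing from input_tree")
            flat.extend(flatten_up_to(shallow_tree[key], input_tree[key]))
        return flat

    if isinstance(input_tree, Mapping):
        raise TypeError("shallow_tree is a sequence but input_tree is a dict")
    if len(shallow_tree) != len(input_tree):
        raise ValueError("sequences differ in length")
    flat = []
    for shallow_child, input_child in zip(shallow_tree, input_tree):
        flat.extend(flatten_up_to(shallow_child, input_child))
    return flat


# The shallow tree only describes the top two levels, so the inner
# list [1, 2] and tuple (3, 4) are returned whole rather than flattened.
shallow = [0, {"a": 0, "b": 0}]
nested = [[1, 2], {"a": (3, 4), "b": 5}]
print(flatten_up_to(shallow, nested))  # [[1, 2], (3, 4), 5]
```

The same recursion should work when the leaves are tensors, since a tensor is none of the container types checked above. Note that `torch.utils._pytree` is a private module, so its behavior may change between releases.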