# How to partially flatten a structure, retaining some of the nested structure

I was wondering if PyTorch has an equivalent of TensorFlow's `flatten_up_to` function, which partially flattens a structure while retaining some of the nesting. Thanks!

Yes, `torch.flatten` has `start_dim` and `end_dim` parameters: the dimensions from `start_dim` through `end_dim` (inclusive) are merged into one.

Example:

``````t = torch.rand((3, 4, 5, 6, 7, 8))

print(t.shape)
t = torch.flatten(t, start_dim=2, end_dim=4)
print(t.shape)
``````

Output:

``````torch.Size([3, 4, 5, 6, 7, 8])
torch.Size([3, 4, 210, 8])
``````

I don’t fully understand, though: in the TensorFlow function, you pass two structures and the second one is flattened only as deep as the first one is nested. So you never specify the dimensions explicitly, you just pass the structures. In your example, however, I wouldn’t know how to flatten tensor `t` up to another structure.

Hi,

I’m going to use the first example from this link.

``````input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]

flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

# Output is:
# [[2, 2], [3, 3], [4, 9], [5, 5]]
# [True, True, False, True]
``````

So this is what we are given in the example.
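For reference, here is how I read that behaviour on plain nested lists, as a rough pure-Python sketch. The helper name `flatten_up_to_lists` is made up for this post and is not part of TensorFlow or PyTorch; it only handles lists and does no other validation.

```python
def flatten_up_to_lists(shallow, tree):
    """Flatten `tree` only as deep as `shallow` is nested (hypothetical helper)."""
    if not isinstance(shallow, list):
        # `shallow` is a leaf here, so the matching part of `tree` is kept intact.
        return [tree]
    if not isinstance(tree, list) or len(tree) != len(shallow):
        raise ValueError("tree does not match the structure of shallow")
    out = []
    for s, t in zip(shallow, tree):
        out.extend(flatten_up_to_lists(s, t))
    return out


input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]

print(flatten_up_to_lists(shallow_tree, input_tree))    # [[2, 2], [3, 3], [4, 9], [5, 5]]
print(flatten_up_to_lists(shallow_tree, shallow_tree))  # [True, True, False, True]
```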

The input looks like this:

``````         /-2
      /-|
     |   \-2
   /-|
  |  |   /-3
  |   \-|
  |      \-3
--|
  |      /-4
  |   /-|
  |  |   \-9
   \-|
     |   /-5
      \-|
         \-5
``````

With shape `torch.Size([2, 2, 2])` (if we convert this to a torch tensor).

The shallow tree looks like this:

``````      /-True
   /-|
  |   \-True
--|
  |   /-False
   \-|
      \-True
``````

With shape `torch.Size([2, 2])` (also already converted to torch tensor).

And the expected tree after using the `flatten_up_to` function is:

``````      /-2
   /-|
  |   \-2
  |
  |   /-3
  |--|
  |   \-3
--|
  |   /-4
  |--|
  |   \-9
  |
  |   /-5
   \-|
      \-5
``````

With shape `torch.Size([4, 2])`.
This means we have to flatten from the left (merging the leading dimensions) until the depth matches that of the shallow tree.

Original depth: 3
Shallow depth: 2

We want to “get rid” of 1 dimension, so that they match.

We can achieve this by using

``````t = torch.flatten(input_tree, end_dim=1)
``````
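To check this on the example above, here is a small sketch. It assumes the nested list is first converted with `torch.tensor`, since `torch.flatten` expects a tensor rather than a Python list.

```python
import torch

# The nested list from the example, converted to a tensor of shape [2, 2, 2].
input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])

# Merge dimensions 0 and 1; start_dim defaults to 0.
flat = torch.flatten(input_tree, end_dim=1)

print(flat.shape)     # torch.Size([4, 2])
print(flat.tolist())  # [[2, 2], [3, 3], [4, 9], [5, 5]]
```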

To make this more generic, we can create our own function to take care of this:

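``````import torch

def flatten_up_to(shallow_tree, input_tree):
    len_i = len(input_tree.shape)
    len_s = len(shallow_tree.shape)
    d = len_i - len_s
    assert d > 0, "Shallow tree must be smaller than the input tree."
    # Assumed final step, following the reasoning above: merge the first d + 1
    # dimensions so the result has as many dimensions as shallow_tree.
    return torch.flatten(input_tree, end_dim=d)
``````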

You can try this and see if it works for you. If this is not the behavior you were expecting, please let me know; perhaps I did not understand correctly what you need or how `flatten_up_to` is supposed to work.
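One caveat: the function above works from the difference in depth, so it rejects the `flatten_up_to(shallow_tree, shallow_tree)` call from the TensorFlow example, where that difference is 0. If the goal is instead to merge exactly as many leading dimensions as the shallow tree has, a variant might look like the sketch below. This is only my reading of the TensorFlow behaviour, and the name `flatten_leading_dims` is made up for illustration.

```python
import torch


def flatten_leading_dims(shallow_tree, input_tree):
    """Merge the first shallow_tree.dim() dimensions of input_tree into one.

    Hypothetical helper: assumes both arguments are tensors and that the
    leading dimensions of input_tree have the same sizes as shallow_tree.
    """
    n = shallow_tree.dim()
    if n == 0:
        raise ValueError("shallow_tree must have at least one dimension")
    if tuple(input_tree.shape[:n]) != tuple(shallow_tree.shape):
        raise ValueError("leading dimensions of input_tree must match shallow_tree")
    return torch.flatten(input_tree, start_dim=0, end_dim=n - 1)


input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])
shallow_tree = torch.tensor([[True, True], [False, True]])

print(flatten_leading_dims(shallow_tree, input_tree).shape)      # torch.Size([4, 2])
print(flatten_leading_dims(shallow_tree, shallow_tree).tolist()) # [True, True, False, True]
```

For the example in this thread both versions give the same `[4, 2]` result; they only differ when the depth difference is not equal to the shallow depth minus one.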