How to partially flatten a structure, retaining some of the nested structure

I was wondering if PyTorch has an equivalent of the TensorFlow “flatten_up_to” function, which partially flattens a structure while retaining some of the nested structure. Thanks!


Yes, torch.flatten has start_dim and end_dim parameters.

Example:

import torch

t = torch.rand((3, 4, 5, 6, 7, 8))

print(t.shape)
t = torch.flatten(t, start_dim=2, end_dim=4)
print(t.shape)

Output:

torch.Size([3, 4, 5, 6, 7, 8])
torch.Size([3, 4, 210, 8])
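For reference, a few other combinations (a quick sketch). start_dim defaults to 0 and end_dim to -1, negative indices count from the end, and the merged dimension's size is the product of the dimensions it replaces (here 5 × 6 × 7 = 210):

```python
import torch

t = torch.rand((3, 4, 5, 6, 7, 8))

# Defaults (start_dim=0, end_dim=-1): everything becomes one dimension.
print(torch.flatten(t).shape)                # torch.Size([20160])
# Keep the first dimension (e.g. a batch dimension), flatten the rest.
print(torch.flatten(t, start_dim=1).shape)   # torch.Size([3, 6720])
# Negative indices count from the end: merge only the last three dimensions.
print(torch.flatten(t, start_dim=-3).shape)  # torch.Size([3, 4, 5, 336])
```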

Hi! Thanks for your answer.

I don’t fully understand, though: in the TensorFlow function, you pass two structures and the first one is flattened up to the second one. So you never specify the dimensions explicitly; you just pass the structures. In your example, however, I wouldn’t know how to flatten tensor t up to another structure.


Hi,


I’m going to use the first example from this link.

input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]

flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

# Output is:
# [[2, 2], [3, 3], [4, 9], [5, 5]]
# [True, True, False, True]
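To make the semantics concrete, here is a minimal pure-Python sketch of what flatten_up_to does on nested lists. It is only an illustration under assumptions: flatten_up_to_lists is a hypothetical name, and unlike the TensorFlow version it understands only plain lists (no tuples, dicts or other nest types) and does just a basic structure check.

```python
def flatten_up_to_lists(shallow_tree, input_tree):
    """Hypothetical sketch of flatten_up_to for plain nested lists only."""
    # A non-list node in the shallow tree is a leaf: keep the whole input subtree.
    if not isinstance(shallow_tree, list):
        return [input_tree]
    assert isinstance(input_tree, list) and len(input_tree) == len(shallow_tree), \
        "input_tree must have the same structure as shallow_tree at this level"
    flat = []
    for shallow_child, input_child in zip(shallow_tree, input_tree):
        flat.extend(flatten_up_to_lists(shallow_child, input_child))
    return flat


input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
print(flatten_up_to_lists(shallow_tree, input_tree))    # [[2, 2], [3, 3], [4, 9], [5, 5]]
print(flatten_up_to_lists(shallow_tree, shallow_tree))  # [True, True, False, True]
```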

So this is what we are given in the example.

The input looks like this:

         /-2
      /-|
     |   \-2
   /-|
  |  |   /-3
  |   \-|
  |      \-3
--|
  |      /-4
  |   /-|
  |  |   \-9
   \-|
     |   /-5
      \-|
         \-5

It has shape torch.Size([2, 2, 2]) if we convert it to a torch tensor.

The shallow tree looks like this:

      /-True
   /-|
  |   \-True
--|
  |   /-False
   \-|
      \-True

It has shape torch.Size([2, 2]), again once converted to a torch tensor.
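A quick sketch of that conversion, assuming the nested lists are rectangular so that torch.tensor accepts them. Note that the shallow tree's shape is exactly the leading part of the input tree's shape:

```python
import torch

input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])
shallow_tree = torch.tensor([[True, True], [False, True]])

print(input_tree.shape)    # torch.Size([2, 2, 2])
print(shallow_tree.shape)  # torch.Size([2, 2])
# The shallow tree covers the first shallow_tree.dim() dimensions of the input.
print(input_tree.shape[:shallow_tree.dim()] == shallow_tree.shape)  # True
```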

And the expected tree after using the flatten_up_to function is:

      /-2
   /-|
  |   \-2
  |
  |   /-3
  |--|
  |   \-3
--|
  |   /-4
  |--|
  |   \-9
  |
  |   /-5
   \-|
      \-5

It has shape torch.Size([4, 2]).
This means we have to flatten from the left: the leading dimensions covered by the shallow tree are merged into one, so its 2 × 2 leaves become 4 entries.

Original depth: 3
Shallow depth: 2

We want to “get rid” of 1 dimension, so that they match.
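The shape arithmetic can be checked directly. This is a sketch for the example above only: k (the number of leading dimensions covered by the shallow tree) is a name introduced here, and reshape with -1 is used just as a cross-check of the expected shape:

```python
import math

import torch

input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])
k = 2  # the shallow tree has shape [2, 2], so it covers 2 leading dimensions

# Expected result: product of the covered dimensions, then the remaining ones.
expected = (math.prod(input_tree.shape[:k]),) + tuple(input_tree.shape[k:])
print(expected)  # (4, 2)

# reshape with -1 performs the same merge of the leading dimensions.
merged = input_tree.reshape(-1, *input_tree.shape[k:])
print(merged.shape)  # torch.Size([4, 2])
```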

Since torch.flatten expects a tensor, we first convert the nested list and then flatten dimensions 0 through 1:

t = torch.flatten(torch.tensor(input_tree), end_dim=1)
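Checking the values against the TensorFlow output above (a quick sketch; flattening the shallow tree with the default end_dim=-1 gives its flat version):

```python
import torch

input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]

t = torch.flatten(torch.tensor(input_tree), end_dim=1)
print(t.tolist())  # [[2, 2], [3, 3], [4, 9], [5, 5]]

flat_shallow = torch.flatten(torch.tensor(shallow_tree))
print(flat_shallow.tolist())  # [True, True, False, True]
```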

To make this more generic, we can write our own function that works out the dimensions from the shapes of the two tensors:

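import torch

def flatten_up_to(shallow_tree, input_tree):
    # Both arguments are tensors; the shallow tree's shape must be a
    # non-empty prefix of the input tree's shape, e.g. [2, 2] vs [2, 2, 2].
    len_s = shallow_tree.dim()
    assert len_s > 0 and input_tree.shape[:len_s] == shallow_tree.shape, \
        "Shallow tree must match the leading dimensions of the input tree."

    # Merge the len_s leading dimensions (one entry per shallow leaf) into one.
    return torch.flatten(input_tree, end_dim=len_s - 1)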
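One caveat worth keeping in mind: this tensor-based approach assumes the nested lists are rectangular. A ragged structure (branches with different numbers of leaves) cannot be turned into a tensor in the first place, so a recursive traversal of the lists would be needed instead. A small sketch of the failure:

```python
import torch

# The second branch has only one leaf, so the structure is ragged.
ragged_tree = [[[2, 2], [3, 3]], [[4, 9]]]

try:
    torch.tensor(ragged_tree)
    converted = True
except (ValueError, RuntimeError) as err:
    converted = False
    print("cannot convert to a tensor:", err)
```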

You can try this and see if it works for you. If this is not the behavior you expected, please let me know; perhaps I misunderstood what you need or how flatten_up_to is supposed to work :smile:

Hope this helps!
