I was wondering whether PyTorch has an equivalent of the TensorFlow “flatten_up_to” function. It lets you partially flatten a structure while retaining some of its nesting. Thanks!
Yes, torch.flatten has start_dim and end_dim parameters.
Example:
import torch

t = torch.rand((3, 4, 5, 6, 7, 8))
print(t.shape)
# Merge dims 2, 3 and 4 (sizes 5, 6, 7) into one dim of size 5 * 6 * 7 = 210
t = torch.flatten(t, start_dim=2, end_dim=4)
print(t.shape)
Output:
torch.Size([3, 4, 5, 6, 7, 8])
torch.Size([3, 4, 210, 8])
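As a quick sanity check on what the two parameters do, the sketch below compares the flattened tensor with an equivalent reshape, and shows that negative dims can be used to count from the end. The variable names flat, reshaped and flat_neg are only for illustration.

```python
import torch

t = torch.rand((3, 4, 5, 6, 7, 8))

# Merge dims 2..4 (sizes 5, 6, 7) into a single dim of size 210
flat = torch.flatten(t, start_dim=2, end_dim=4)

# The same merge written as a reshape; -1 lets PyTorch infer 210
reshaped = t.reshape(3, 4, -1, 8)

# Negative dims count from the end: for a 6-D tensor, -4 is dim 2 and -2 is dim 4
flat_neg = torch.flatten(t, start_dim=-4, end_dim=-2)

print(flat.shape)                   # torch.Size([3, 4, 210, 8])
print(torch.equal(flat, reshaped))  # True
print(torch.equal(flat, flat_neg))  # True
```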
Hi! Thanks for your answer.
I don’t fully understand, though: with the TensorFlow function, you pass two structures and it flattens the second one up to the structure of the first. So you don’t specify dimensions explicitly; you just pass the structures. In your example, however, I wouldn’t know how to flatten tensor t up to another structure.
Hi,
I’m going to use the first example from this link.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
# Output is:
# [[2, 2], [3, 3], [4, 9], [5, 5]]
# [True, True, False, True]
So this is what we are given in the example.
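To make the semantics of the example concrete, here is a minimal pure-Python sketch for nested lists only. The name flatten_up_to_lists is hypothetical and this is not TensorFlow's implementation (which also handles tuples, dicts and other containers); it just walks both trees in parallel and stops descending wherever the shallow tree has a leaf.

```python
from typing import Any, List


def flatten_up_to_lists(shallow: Any, tree: Any) -> List[Any]:
    """Hypothetical sketch: flatten `tree` up to the structure of `shallow`.

    Only plain Python lists count as nesting here. Wherever `shallow` has a
    leaf (anything that is not a list), the matching subtree of `tree` is
    emitted as-is instead of being flattened further.
    """
    if not isinstance(shallow, list):
        # Leaf of the shallow tree: keep the whole input subtree
        return [tree]
    if not isinstance(tree, list) or len(tree) != len(shallow):
        raise ValueError("input tree does not match the shallow tree's structure")
    out: List[Any] = []
    for s, t in zip(shallow, tree):
        out.extend(flatten_up_to_lists(s, t))
    return out


input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]

print(flatten_up_to_lists(shallow_tree, input_tree))    # [[2, 2], [3, 3], [4, 9], [5, 5]]
print(flatten_up_to_lists(shallow_tree, shallow_tree))  # [True, True, False, True]
```

Both printed results match the outputs quoted from the TensorFlow example.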
The input looks like this:
root
├── [0]
│   ├── [0]
│   │   ├── 2
│   │   └── 2
│   └── [1]
│       ├── 3
│       └── 3
└── [1]
    ├── [0]
    │   ├── 4
    │   └── 9
    └── [1]
        ├── 5
        └── 5
With shape torch.Size([2, 2, 2]) (if we convert this to a torch tensor).
The shallow tree looks like this:
root
├── [0]
│   ├── True
│   └── True
└── [1]
    ├── False
    └── True
With shape torch.Size([2, 2]) (also already converted to a torch tensor).
And the expected tree after applying flatten_up_to is:
root
├── [0]
│   ├── 2
│   └── 2
├── [1]
│   ├── 3
│   └── 3
├── [2]
│   ├── 4
│   └── 9
└── [3]
    ├── 5
    └── 5
With shape torch.Size([4, 2]).
This means we have to flatten from the left (starting at dim 0) until the input has the same depth as the shallow tree.
Original depth: 3
Shallow depth: 2
We want to “get rid” of 1 dimension, so that they match.
We can achieve this with the following (where input_tree is the tensor version of the input):
t = torch.flatten(input_tree, end_dim=1)
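Here is a self-contained version of that one-liner, converting the nested list from the example into a tensor first; it reproduces the first TensorFlow output as a [4, 2] tensor.

```python
import torch

# The example input converted to a tensor of shape [2, 2, 2]
input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])

# start_dim defaults to 0, so this merges dims 0 and 1 into one dim of size 4
t = torch.flatten(input_tree, end_dim=1)

print(t.shape)     # torch.Size([4, 2])
print(t.tolist())  # [[2, 2], [3, 3], [4, 9], [5, 5]]
```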
To make this more generic, we can create our own function to take care of this:
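import torch

def flatten_up_to(shallow_tree, input_tree):
    # Both arguments are tensors; the shallow tree may not be deeper than the input
    len_i = len(input_tree.shape)
    len_s = len(shallow_tree.shape)
    assert 0 < len_s <= len_i, "Shallow tree must not have more dimensions than the input tree."
    # Merge the first len_s dims, giving one row per element of the shallow tree
    # (this also reproduces flatten_up_to(shallow_tree, shallow_tree) from the example)
    return torch.flatten(input_tree, end_dim=len_s - 1)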
You can try it and see whether it works for you. If this is not the behavior you were expecting, please let me know; perhaps I did not correctly understand what you need or how flatten_up_to is supposed to work.
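If you want to check a tensor version against both calls from the TensorFlow example, a slightly stricter sketch could also verify that the input's leading dimensions match the shallow tree's shape, roughly mirroring TensorFlow's structure check. The name flatten_up_to_tensor is hypothetical, and it assumes both structures are regular (non-ragged) and already converted to tensors. It merges the first k dims, where k is the shallow tree's rank, so the result has one row per element of the shallow tree, which is what the expected output above shows.

```python
import torch


def flatten_up_to_tensor(shallow_tree: torch.Tensor, input_tree: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten `input_tree` up to the shape of `shallow_tree`.

    Assumes both trees are regular (non-ragged) and already tensors.
    """
    k = shallow_tree.dim()
    if k == 0:
        raise ValueError("shallow_tree must have at least one dimension")
    if input_tree.shape[:k] != shallow_tree.shape:
        raise ValueError(
            f"leading dims {tuple(input_tree.shape[:k])} do not match "
            f"shallow_tree shape {tuple(shallow_tree.shape)}"
        )
    # Merge dims 0 .. k-1; the remaining dims stay as each row's sub-structure
    return torch.flatten(input_tree, start_dim=0, end_dim=k - 1)


input_tree = torch.tensor([[[2, 2], [3, 3]], [[4, 9], [5, 5]]])
shallow_tree = torch.tensor([[True, True], [False, True]])

print(flatten_up_to_tensor(shallow_tree, input_tree).tolist())
# [[2, 2], [3, 3], [4, 9], [5, 5]]
print(flatten_up_to_tensor(shallow_tree, shallow_tree).tolist())
# [True, True, False, True]
```

For a deeper input, e.g. a tensor of shape [2, 2, 3, 5] with the same shallow tree, this sketch returns shape [4, 3, 5]: four rows, each keeping its [3, 5] sub-structure.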
Hope this helps!