What is the difference between .flatten() and .view(-1) in PyTorch?

Both .flatten() and .view(-1) flatten a tensor in PyTorch. What's the difference?

  1. Does .flatten() copy the tensor's data?
  2. Is .view(-1) faster?
  3. Is there any situation in which .flatten() doesn't work?

I've tried reading PyTorch's docs, but they don't answer these questions.


view(*shape) → Tensor
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
Returns a new tensor with the same data as the self tensor but of a different shape.

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
https://pytorch.org/docs/stable/torch.html#torch.flatten
Flattens a contiguous range of dims in a tensor.
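To make the two signatures concrete: view() takes the complete target shape (with at most one -1 inferred), whereas flatten() only needs the inclusive range of dimensions to merge. A minimal sketch, assuming a contiguous tensor; the shape (2, 3, 4, 5) is purely illustrative:

```python
import torch

x = torch.zeros(2, 3, 4, 5)  # freshly allocated, hence contiguous

# view: spell out the full target shape; -1 lets PyTorch infer one dimension.
v_all = x.view(-1)          # shape (120,)
v_mid = x.view(2, -1, 5)    # merges dims 1 and 2 -> shape (2, 12, 5)

# flatten: name the (inclusive) range of dims to merge instead.
f_all = torch.flatten(x)                          # shape (120,)
f_mid = torch.flatten(x, start_dim=1, end_dim=2)  # shape (2, 12, 5)

print(v_all.shape, f_all.shape)  # torch.Size([120]) torch.Size([120])
print(v_mid.shape, f_mid.shape)  # torch.Size([2, 12, 5]) torch.Size([2, 12, 5])
```

For partial flattening, flatten() saves you from recomputing the merged size by hand.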

Well, I read those descriptions in the docs, but they didn't answer my questions.

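  1. Not in the common case. When the tensor can be viewed with the flattened shape (for example, when it is contiguous), torch.flatten() returns a view and copies no data, so it behaves much like a wrapper around view(). Only when such a view is impossible (for example, on a transposed, non-contiguous tensor) does it copy, just as reshape() does. A simple way to verify the no-copy case is to run the following code: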

    import torch

    # Create a (2, 3, 4) tensor filled with zeros.
    a = torch.zeros(2, 3, 4)

    # Flatten the 2nd and 3rd dimensions of the original
    # tensor using the `view` and `flatten` methods.
    b = a.view(2, 12)
    c = torch.flatten(a, start_dim=1)

    # Write a distinct value through each flattened tensor.
    b[0, 2] = 1
    c[0, 4] = 2

    # If no copy was made, all three tensors share one storage,
    # so both writes are visible everywhere and the data match.
    print("Tensors A and B data match?", torch.equal(a.view(-1), b.view(-1)))
    print("Tensors A and C data match?", torch.equal(a.view(-1), c.view(-1)))
    print("Tensors B and C data match?", torch.equal(b.view(-1), c.view(-1)))
    

    Output:

    Tensors A and B data match? True
    Tensors A and C data match? True
    Tensors B and C data match? True
    
  2. Yes, but the difference is negligible in practice. The only overhead flatten() adds is a simple internal computation of the output shape before it calls view() (or a similar method). This difference amounts to less than 1 μs per call.

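  3. Not that I know of. If anything, it is the other way around: view(-1) raises a RuntimeError when the tensor's strides are incompatible with the new shape (for example, a transposed, non-contiguous tensor), whereas flatten() still works in that case by copying the data.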
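Both the no-copy and the copy cases can be checked directly. The sketch below (tensor names are only illustrative) compares storage addresses via data_ptr(): on a contiguous tensor flatten() returns a view that shares memory, while on a transposed, non-contiguous tensor view(-1) raises a RuntimeError and flatten() quietly returns a copy, giving the same result as reshape(-1).

```python
import torch

# Contiguous case: flatten() returns a view of the same memory.
a = torch.arange(6).reshape(2, 3)
fa = a.flatten()
print(fa.data_ptr() == a.data_ptr())   # True: no copy was made
fa[0] = 100
print(a[0, 0].item())                  # 100: the write is visible through `a`

# Non-contiguous case: a transpose cannot be viewed as 1-D.
t = torch.arange(6).reshape(2, 3).t()
print(t.is_contiguous())               # False
try:
    t.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)                     # True

ft = t.flatten()                       # works, but copies the data
print(ft.tolist())                     # [0, 3, 1, 4, 2, 5]
print(ft.data_ptr() == t.data_ptr())   # False: new storage
print(torch.equal(ft, t.reshape(-1)))  # True: same result as reshape(-1)
```

Because of the silent copy, writes through the flattened tensor in the non-contiguous case do not reach the original; if you rely on shared memory, view(-1) fails loudly instead.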

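To check the overhead claim on your own machine, a rough micro-benchmark with Python's timeit module is enough. This is only a sketch: absolute numbers depend on hardware, PyTorch version and Python call overhead (which dominates both calls), so treat the printed values as indicative rather than definitive. The tensor shape and call count below are arbitrary choices.

```python
import timeit

import torch

x = torch.zeros(64, 128)  # contiguous, so both calls return views

# Time many calls and report the average cost of a single call.
n = 100_000
t_view = timeit.timeit(lambda: x.view(-1), number=n) / n
t_flat = timeit.timeit(lambda: x.flatten(), number=n) / n

print(f"view(-1):  {t_view * 1e6:.3f} us per call")
print(f"flatten(): {t_flat * 1e6:.3f} us per call")
```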