[HELP] Convert TensorFlow Flatten layer into PyTorch

Hi, I have this layer in TensorFlow:

Flatten()

and I’m trying to convert it into PyTorch. This is my attempt:

nn.Flatten(2)

In TF the shape changes like this:
(None, 1, 1, 64) -> Flatten() -> (None, 64)
but in PyTorch what I get is:
(-1, 64, 1, 1) -> Flatten(2) -> (-1, 64, 1)

What am I doing wrong?
Thanks!

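I just figured out how the Flatten layer in PyTorch works: the first argument says which dimension to start flattening from, and the second says where to end. With nn.Flatten(2) the flattening started at dimension 2, so only the two trailing 1s were merged and the 64 was left alone. So it should be: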

nn.Flatten(1, 3)
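
For anyone landing here later, a quick shape check might look like this. It is only a minimal sketch: the batch size of 8 and the variable names are placeholders I picked, not anything from my real model. As far as I can tell from the docs, nn.Flatten defaults to start_dim=1 and end_dim=-1, so plain nn.Flatten() should give the same result as nn.Flatten(1, 3) on a 4-D input.

```python
import torch
from torch import nn

# Dummy input in PyTorch's channels-first layout: (batch, channels, height, width).
# The batch size of 8 is an arbitrary placeholder.
x = torch.zeros(8, 64, 1, 1)

# My original attempt: flattening starts at dim 2, so only the two trailing 1s merge.
shape_from_dim2 = tuple(nn.Flatten(2)(x).shape)

# The fix: flatten dims 1 through 3 into a single dimension.
shape_dims_1_to_3 = tuple(nn.Flatten(1, 3)(x).shape)

# Default arguments (start_dim=1, end_dim=-1) flatten everything except the batch dim.
shape_default = tuple(nn.Flatten()(x).shape)

print(shape_from_dim2)
print(shape_dims_1_to_3)
print(shape_default)
```

One thing to keep in mind: the shapes match here because the spatial dimensions are 1x1. If they were larger, TF's (N, H, W, C) layout and PyTorch's (N, C, H, W) layout would flatten the elements in a different order, which matters when porting weights of a following dense layer.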