How to use nn.init.orthogonal_ with more than 2 dimensions

I want to use nn.init.orthogonal_ to get an orthogonal matrix. For 2 dimensions it works, as in this code:

torch.nn.init.orthogonal_(torch.empty(2, 2))
tensor([[ 0.8164,  0.5775],
        [ 0.5775, -0.8164]])

But for more than 2 dimensions, such as this:

torch.nn.init.orthogonal_(torch.empty(3, 2, 2))
tensor([[[ 0.1439,  0.1052],
         [ 0.1259, -0.9759]],

        [[-0.4479, -0.0665],
         [ 0.8906,  0.0417]],

        [[ 0.1099, -0.9895],
         [-0.0143, -0.0923]]])

Obviously, none of the 2 * 2 matrices is orthogonal, but I want the above code to give me three 2 * 2 orthogonal matrices. I read the documentation for torch.nn.init.orthogonal_, but I cannot understand what "flattened" means. Can anyone help me?

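You can check the source here, but it basically does inp = inp.view(inp.size(0), -1) and then transposes the result when it has fewer rows than columns, so that an orthogonal matrix can always be found. In other words, "flattened" means the tensor is treated as a 2-D matrix with size(0) rows and all remaining dimensions merged into the columns. Your (3, 2, 2) tensor becomes a 3 x 4 matrix, and it is the rows of that 3 x 4 matrix that are orthonormal, not the individual 2 x 2 blocks.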
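A quick way to see this is to take the output from the question and reshape it the same way. The sketch below only uses public PyTorch calls; the seed is there just to make reruns repeatable. It checks that the flattened 3 x 4 matrix has orthonormal rows, while the 2 x 2 slices generally are not orthogonal.

```python
import torch

torch.manual_seed(0)  # only for reproducibility

# Same call as in the question: a 3-D tensor of shape (3, 2, 2).
w = torch.nn.init.orthogonal_(torch.empty(3, 2, 2))

# "Flattened": dim 0 is kept, everything after it is merged into one dim,
# so orthogonal_ effectively works on a 3 x 4 matrix.
flat = w.view(w.size(0), -1)
print(flat.shape)  # torch.Size([3, 4])

# The 3 rows of that 3 x 4 matrix are orthonormal: flat @ flat.T is I_3.
print(torch.allclose(flat @ flat.T, torch.eye(3), atol=1e-5))  # True

# The individual 2 x 2 slices are, in general, not orthogonal.
print([torch.allclose(m @ m.T, torch.eye(2), atol=1e-5) for m in w])
```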
If you want many 2x2 orthogonal matrices, you might have to call orthogonal_ once for each 2x2 slice.
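A minimal sketch of the per-slice approach is below. The helper names `orthogonal_stack_` and `batched_orthogonal` are hypothetical, made up for this example. The first simply calls orthogonal_ on each `tensor[i]`; because indexing returns a view, the original tensor is filled in place. The second is an alternative (not part of torch.nn.init) that builds the whole batch at once with `torch.linalg.qr`, which accepts batched input, applying the same sign correction on the diagonal of R that orthogonal_ applies.

```python
import torch

def orthogonal_stack_(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: give each tensor[i] its own orthogonal matrix, in place."""
    for i in range(tensor.size(0)):
        # tensor[i] is a view, so orthogonal_ writes into `tensor` itself.
        torch.nn.init.orthogonal_(tensor[i])
    return tensor

def batched_orthogonal(n: int, size: int) -> torch.Tensor:
    """Hypothetical alternative: n random (size x size) orthogonal matrices via batched QR."""
    a = torch.randn(n, size, size)
    q, r = torch.linalg.qr(a)  # batched QR over the leading dimension
    # Scale column j of each q by sign(r[j, j]), as orthogonal_ does for one matrix.
    signs = torch.sign(torch.diagonal(r, dim1=-2, dim2=-1))  # shape (n, size)
    return q * signs.unsqueeze(-2)  # broadcast over rows -> scales columns

w = orthogonal_stack_(torch.empty(3, 2, 2))
print(all(torch.allclose(m @ m.T, torch.eye(2), atol=1e-5) for m in w))  # True

b = batched_orthogonal(3, 2)
eye = torch.eye(2).expand(3, 2, 2)
print(torch.allclose(b @ b.transpose(-2, -1), eye, atol=1e-5))  # True
```

The loop is the most direct reading of "call orthogonal_ multiple times"; the batched QR version avoids the Python loop, at the cost of reimplementing what orthogonal_ does internally.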