I want to use `nn.init.orthogonal_` to get an orthogonal matrix. For a 2-D tensor it works, as in this code:
```
torch.nn.init.orthogonal_(torch.empty(2, 2))
tensor([[ 0.8164,  0.5775],
        [ 0.5775, -0.8164]])
```
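A quick way to confirm that a 2 × 2 result like this is orthogonal is to check that QᵀQ is (numerically) the identity. This is a minimal sketch, assuming a recent PyTorch; the values are random on each run, so only the check is meaningful, and the tolerance `atol=1e-5` is an arbitrary choice for float32.

```python
import torch

# Initialize a 2 x 2 matrix with orthogonal_ (values differ on every run).
w = torch.nn.init.orthogonal_(torch.empty(2, 2))

# A square matrix Q is orthogonal when Q^T Q equals the identity matrix.
identity = torch.eye(2)
print(torch.allclose(w.T @ w, identity, atol=1e-5))  # True
```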
But for a tensor with more than 2 dimensions, such as this:
```
torch.nn.init.orthogonal_(torch.empty(3, 2, 2))
tensor([[[ 0.1439,  0.1052],
         [ 0.1259, -0.9759]],

        [[-0.4479, -0.0665],
         [ 0.8906,  0.0417]],

        [[ 0.1099, -0.9895],
         [-0.0143, -0.0923]]])
```
Obviously, none of these 2 × 2 matrices is orthogonal, but I want the code above to give me three 2 × 2 orthogonal matrices. I read the documentation for `torch.nn.init.orthogonal_`, but I cannot understand what "flattened" means there. Can anyone help me?
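A minimal sketch of what "flattened" appears to mean, assuming a recent PyTorch: `orthogonal_` keeps dim 0 as the rows and merges all trailing dims into the columns, so a `(3, 2, 2)` tensor is initialized as one 3 × 4 matrix whose rows are orthonormal (a single semi-orthogonal matrix, not three 2 × 2 ones). The second half assumes the goal is three independent 2 × 2 orthogonal matrices and simply initializes each slice separately; the names `flat` and `w2` are illustrative only.

```python
import torch

w = torch.empty(3, 2, 2)
torch.nn.init.orthogonal_(w)

# orthogonal_ treats dim 0 as rows and flattens the remaining dims into columns,
# so a (3, 2, 2) tensor is handled as a single 3 x 4 matrix.
flat = w.reshape(w.size(0), -1)  # shape (3, 4)

# With fewer rows than columns, the rows are orthonormal: flat @ flat.T == I (3 x 3).
print(torch.allclose(flat @ flat.T, torch.eye(3), atol=1e-5))  # True

# To get three independent 2 x 2 orthogonal matrices, initialize each slice
# separately; w2[i] is a view, so orthogonal_ writes into w2 in place.
w2 = torch.empty(3, 2, 2)
for i in range(w2.size(0)):
    torch.nn.init.orthogonal_(w2[i])

# Each 2 x 2 slice now satisfies Q^T Q == I on its own.
for m in w2:
    print(torch.allclose(m.T @ m, torch.eye(2), atol=1e-5))  # True
```

The per-slice loop works because indexing the first dimension of a contiguous tensor returns a contiguous view, so the in-place initialization lands in the original storage.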