Norm of output from pytorch3d.transforms.rotation_conversions.matrix_to_axis_angle is not in [0,pi]

I’m noticing that converting a set of rotation matrices to axis-angle form produces axis-angle vectors whose norm is greater than pi, which seems wrong. I’m using a slightly modified version of rotation_conversions.py, for which I show a diff against Meta’s original.

My questions are:
(A) is this behavior expected?
(B) If it is expected, what is the best way to normalize the vectors so that their magnitudes always lie in [0, pi]?

Thanks in advance!

Instructions To Reproduce the Issue:

  1. Here is a diff between my rotation_conversions.py, which I am using to reproduce this behavior, and Meta’s original (meta_rotation_conversions.py).

I created meta_rotation_conversions.py by copy/pasting the code from the PyTorch3D documentation page for pytorch3d.transforms.rotation_conversions into a file I called tmp.py,
and then running sed s:"\[docs\]":'':g tmp.py > meta_rotation_conversions.py

diff rotation_conversions.py meta_rotation_conversions.py 
11a12,13
> from ..common.datatypes import Device
> 
51d52
<     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
70a72
> 
133,134d134
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
136,137d135
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
139,140d136
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
142,143d137
<             # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
<             #  `int`.
158c152
<         F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
---
>         F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :  # pyre-ignore[16]
161a156
> 
220a216
> 
305a302
> 
307c304
<     n: int, dtype: Optional[torch.dtype] = None, device = None
---
>     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
324d320
<     # pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
330a327
> 
332c329
<     n: int, dtype: Optional[torch.dtype] = None, device = None
---
>     n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
349a347
> 
351c349
<     dtype: Optional[torch.dtype] = None, device = None
---
>     dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
366a365
> 
381a381
> 
402a403
> 
419a421
> 
436a439
> 
459a463
> 
475a480
> 
491a497
> 
523a530
> 
554a562
> 
578a587
> 
595c604
<     return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
\ No newline at end of file
---
>     return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
  2. Here are the exact commands I ran to produce axis-angle vectors whose norm exceeds pi:
import torch 
print(torch.__version__)
import rotation_conversions

data = torch.load('ax_ang_outputs.pt', map_location=torch.device('cpu'))
test_R = data['R_pred'][0][0]

print('This is test rotations')
print(test_R)

# get axis angle from matrix 
ax_ang_from_R = rotation_conversions.matrix_to_axis_angle(test_R)

# check norm 
print(torch.norm(ax_ang_from_R, p=2, dim=-1))
  3. Here’s what I observed in the output:
1.9.0
This is test rotations
tensor([[[ 0.1494,  0.4837, -0.8624],
         [ 0.7733, -0.6007, -0.2030],
         [-0.6162, -0.6365, -0.4638]],

        [[-0.0714,  0.4292,  0.9004],
         [ 0.9666,  0.2525, -0.0437],
         [-0.2462,  0.8672, -0.4328]],

        [[ 0.5917,  0.2904, -0.7520],
         [ 0.7764,  0.0455,  0.6285],
         [ 0.2167, -0.9558, -0.1986]],

        [[ 0.0833, -0.4428,  0.8927],
         [-0.8289, -0.5281, -0.1846],
         [ 0.5532, -0.7246, -0.4110]],

        [[-0.1795, -0.3386,  0.9237],
         [-0.7774,  0.6242,  0.0778],
         [-0.6029, -0.7041, -0.3752]],

        [[ 0.5805,  0.8141, -0.0176],
         [ 0.2735, -0.2153, -0.9375],
         [-0.7670,  0.5393, -0.3477]],

        [[ 0.4837, -0.4856, -0.7282],
         [-0.7797, -0.6171, -0.1064],
         [-0.3977,  0.6192, -0.6771]],

        [[-0.6256, -0.7073,  0.3291],
         [-0.5258,  0.6939,  0.4920],
         [-0.5763,  0.1348, -0.8060]],

        [[-0.2255,  0.6738,  0.7036],
         [ 0.4855,  0.7039, -0.5185],
         [-0.8446,  0.2247, -0.4859]],

        [[ 0.6165,  0.3816, -0.6888],
         [-0.3325, -0.6667, -0.6670],
         [-0.7137,  0.6402, -0.2841]],

        [[-0.1925, -0.8007, -0.5672],
         [-0.9634,  0.0443,  0.2645],
         [-0.1867,  0.5974, -0.7799]],

        [[-0.8547, -0.0775, -0.5133],
         [-0.2145,  0.9531,  0.2133],
         [ 0.4727,  0.2924, -0.8313]],

        [[-0.2660,  0.8619,  0.4316],
         [-0.9161, -0.3655,  0.1652],
         [ 0.3001, -0.3515,  0.8868]],

        [[ 0.6904,  0.5019, -0.5209],
         [-0.4371,  0.8633,  0.2525],
         [ 0.5764,  0.0534,  0.8154]],

        [[-0.4372, -0.8915, -0.1191],
         [-0.2856,  0.2632, -0.9215],
         [ 0.8528, -0.3688, -0.3696]],

        [[ 0.4659,  0.7022, -0.5384],
         [ 0.1695,  0.5264,  0.8332],
         [ 0.8684, -0.4795,  0.1262]],

        [[-0.2217, -0.6468, -0.7297],
         [ 0.9226,  0.1031, -0.3716],
         [ 0.3156, -0.7557,  0.5739]],

        [[-0.9733, -0.1339,  0.1866],
         [ 0.0948, -0.9742, -0.2048],
         [ 0.2092, -0.1816,  0.9609]],

        [[-0.2019,  0.9766,  0.0736],
         [-0.3784, -0.1471,  0.9139],
         [ 0.9034,  0.1567,  0.3993]],

        [[ 0.3122,  0.2442, -0.9181],
         [ 0.6225,  0.6774,  0.3919],
         [ 0.7177, -0.6939,  0.0595]],

        [[-0.6140, -0.6281, -0.4779],
         [ 0.7845, -0.4196, -0.4565],
         [ 0.0862, -0.6553,  0.7504]],

        [[-0.8525,  0.3715,  0.3678],
         [-0.2567, -0.9104,  0.3245],
         [ 0.4553,  0.1822,  0.8715]],

        [[ 0.1279,  0.9249, -0.3580],
         [-0.0424,  0.3657,  0.9297],
         [ 0.9909, -0.1038,  0.0860]],

        [[ 0.1305, -0.2174, -0.9673],
         [ 0.9185,  0.3937,  0.0354],
         [ 0.3732, -0.8931,  0.2511]],

        [[-0.8646, -0.5018,  0.0267],
         [ 0.4707, -0.8274, -0.3063],
         [ 0.1758, -0.2522,  0.9516]],

        [[-0.4554,  0.8043,  0.3817],
         [-0.3966, -0.5671,  0.7219],
         [ 0.7971,  0.1774,  0.5772]],

        [[ 0.3839,  0.5596, -0.7345],
         [ 0.3787,  0.6300,  0.6780],
         [ 0.8422, -0.5384,  0.0299]],

        [[-0.3423, -0.6588, -0.6699],
         [ 0.9109, -0.0579, -0.4085],
         [ 0.2303, -0.7501,  0.6200]],

        [[-0.9885,  0.0321,  0.1475],
         [-0.0418, -0.9971, -0.0630],
         [ 0.1450, -0.0684,  0.9871]],

        [[-0.5158,  0.8466,  0.1310],
         [ 0.0590,  0.1877, -0.9805],
         [-0.8547, -0.4980, -0.1467]],

        [[-0.2505,  0.3370, -0.9076],
         [-0.8820, -0.4660,  0.0704],
         [-0.3992,  0.8181,  0.4140]],

        [[-0.2778,  0.1515,  0.9486],
         [-0.1057,  0.9767, -0.1869],
         [-0.9548, -0.1522, -0.2553]],

        [[ 0.6881,  0.7063,  0.1661],
         [ 0.0777,  0.1558, -0.9847],
         [-0.7214,  0.6905,  0.0524]],

        [[ 0.3823, -0.2206,  0.8973],
         [-0.2440, -0.9607, -0.1323],
         [ 0.8912, -0.1684, -0.4211]],

        [[ 0.5232,  0.6215,  0.5831],
         [-0.8507,  0.4215,  0.3141],
         [-0.0506, -0.6604,  0.7492]],

        [[-0.0445,  0.9967, -0.0677],
         [-0.8410, -0.0739, -0.5359],
         [-0.5392,  0.0331,  0.8416]],

        [[ 0.7902, -0.5468, -0.2768],
         [-0.5919, -0.7981, -0.1128],
         [-0.1593,  0.2529, -0.9543]],

        [[-0.4070, -0.7162, -0.5670],
         [-0.7906,  0.5871, -0.1740],
         [ 0.4575,  0.3774, -0.8051]],

        [[-0.3221,  0.9386,  0.1233],
         [ 0.8099,  0.3406, -0.4776],
         [-0.4903, -0.0539, -0.8699]],

        [[ 0.6649, -0.2304,  0.7105],
         [ 0.3991,  0.9137, -0.0771],
         [-0.6314,  0.3348,  0.6994]],

        [[-0.1578, -0.6089, -0.7774],
         [-0.5264, -0.6142,  0.5880],
         [-0.8355,  0.5020, -0.2236]],

        [[ 0.9653, -0.2548, -0.0574],
         [ 0.0103, -0.1827,  0.9831],
         [-0.2610, -0.9496, -0.1737]],

        [[-0.0468,  0.6846,  0.7274],
         [ 0.7944,  0.4669, -0.3884],
         [-0.6055,  0.5597, -0.5658]],

        [[ 0.8492, -0.5221, -0.0788],
         [ 0.0201,  0.1810, -0.9833],
         [ 0.5276,  0.8335,  0.1642]],

        [[ 0.8291, -0.5556, -0.0633],
         [ 0.2867,  0.5196, -0.8048],
         [ 0.4800,  0.6491,  0.5901]],

        [[-0.6534, -0.1525, -0.7415],
         [ 0.6651,  0.3521, -0.6585],
         [ 0.3615, -0.9234, -0.1286]],

        [[-0.1768,  0.9773, -0.1165],
         [-0.4400,  0.0274,  0.8976],
         [ 0.8804,  0.2100,  0.4251]],

        [[ 0.1008, -0.7674, -0.6332],
         [ 0.3219,  0.6273, -0.7091],
         [ 0.9414, -0.1323,  0.3103]],

        [[-0.8693,  0.1492,  0.4712],
         [-0.0189, -0.9627,  0.2700],
         [ 0.4939,  0.2258,  0.8397]],

        [[-0.3912, -0.3320, -0.8583],
         [-0.4360,  0.8882, -0.1448],
         [ 0.8104,  0.3176, -0.4922]],

        [[-0.9286, -0.1917, -0.3177],
         [ 0.3708, -0.4515, -0.8116],
         [ 0.0121, -0.8715,  0.4903]],

        [[ 0.1223,  0.0621, -0.9906],
         [-0.9195, -0.3687, -0.1366],
         [-0.3737,  0.9275,  0.0120]],

        [[ 0.1735,  0.8650,  0.4709],
         [ 0.8893, -0.3431,  0.3025],
         [ 0.4232,  0.3662, -0.8287]],

        [[ 0.4478, -0.6575, -0.6059],
         [ 0.7323,  0.6585, -0.1734],
         [ 0.5130, -0.3661,  0.7764]],

        [[-0.4742, -0.1710,  0.8636],
         [ 0.6836, -0.6896,  0.2389],
         [ 0.5547,  0.7037,  0.4440]],

        [[-0.1458,  0.5584, -0.8167],
         [-0.7840, -0.5686, -0.2489],
         [-0.6034,  0.6040,  0.5207]],

        [[-0.4864,  0.8186,  0.3055],
         [ 0.8697,  0.4870,  0.0800],
         [-0.0833,  0.3046, -0.9488]],

        [[-0.1426, -0.9897,  0.0134],
         [ 0.9831, -0.1432, -0.1144],
         [ 0.1152, -0.0031,  0.9933]]], grad_fn=<SelectBackward>)
tensor([3.4340, 2.2470, 4.4279, 3.5236, 2.0547, 2.0843, 2.7027, 2.6237, 2.0987,
        2.3012, 3.4104, 3.6644, 4.3308, 0.8168, 3.8307, 1.5115, 1.8466, 3.0259,
        4.2176, 1.5462, 2.2674, 3.4726, 1.7825, 1.6834, 2.6265, 3.9048, 1.5489,
        1.9715, 3.1787, 2.4000, 4.0031, 1.8528, 1.6227, 3.1633, 1.2165, 4.5735,
        2.9470, 3.7639, 2.7536, 0.8776, 3.2080, 4.5156, 2.1808, 1.4734, 1.0822,
        3.9158, 4.3418, 1.5516, 3.2293, 4.1915, 2.8080, 2.2360, 3.1000, 1.1137,
        2.6060, 4.0728, 2.9137, 1.7175], grad_fn=<NormBackward1>)

I don’t quite understand the posted diff, as it shows only “cosmetic” changes between these files and nothing that would change the actual behavior of any calculation.
Do you get the expected results with the original file, so that only your changes create the issue?

Sorry about that, it’s definitely overly complicated. The reason is that I wanted to use the standalone rotation_conversions.py script rather than install all of pytorch3d. rotation_conversions.py has one relative dependency, Device from ..common.datatypes, so I removed the dependency on that object and wanted to show exactly how my version differs from the original. But I think you’re right, the changes are just cosmetic.

To verify that my changes are not causing this, I just tried an experiment: I copied Meta’s pytorch3d.common.datatypes module into meta_datatypes.py, then created meta_rotation_conversions.py by copying Meta’s original rotation_conversions.py and replacing the import "from ..common.datatypes import Device" with "from meta_datatypes import Device". I get exactly the same result.

My code:

import torch 
print(torch.__version__)
import meta_rotation_conversions

data = torch.load('ax_ang_outputs.pt', map_location=torch.device('cpu'))
test_R = data['R_pred'][0][0]

# print('This is test rotations')
print(test_R)

# get axis angle from matrix 
ax_ang_from_R = meta_rotation_conversions.matrix_to_axis_angle(test_R)

# check norm 
print(torch.norm(ax_ang_from_R, p=2, dim=-1))

And the outputs:

1.9.0
(identical to the test_R tensor printed in the first post; omitted here)
tensor([3.4340, 2.2470, 4.4279, 3.5236, 2.0547, 2.0843, 2.7027, 2.6237, 2.0987,
        2.3012, 3.4104, 3.6644, 4.3308, 0.8168, 3.8307, 1.5115, 1.8466, 3.0259,
        4.2176, 1.5462, 2.2674, 3.4726, 1.7825, 1.6834, 2.6265, 3.9048, 1.5489,
        1.9715, 3.1787, 2.4000, 4.0031, 1.8528, 1.6227, 3.1633, 1.2165, 4.5735,
        2.9470, 3.7639, 2.7536, 0.8776, 3.2080, 4.5156, 2.1808, 1.4734, 1.0822,
        3.9158, 4.3418, 1.5516, 3.2293, 4.1915, 2.8080, 2.2360, 3.1000, 1.1137,
        2.6060, 4.0728, 2.9137, 1.7175], grad_fn=<NormBackward1>)

Thanks for the follow-up. So the issue is not caused by your changes, since you see the same outputs with the original code.
From the docs, it seems that matrix_to_axis_angle returns:

… where the magnitude is the angle turned anticlockwise in radians around the vector’s direction

i.e., an output in radians. In your code you then compute torch.norm (the L2 norm) over the last dimension, which contains the angles in radians, and this can return values outside of [0, pi].
I’m not sure what exactly you are trying to calculate (maybe the vector norm), so could you explain your use case a bit more and why the results are not expected?
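Per the documentation quote above, the direction of the returned vector is the rotation axis and its magnitude is the rotation angle, so the L2 norm over the last dimension is the angle itself. A minimal pure-Python sketch of that split (split_rotvec is a hypothetical helper for illustration, not a PyTorch3D function):

```python
import math


def split_rotvec(rotvec):
    """Split an axis-angle (rotation) vector into (unit axis, angle in radians).

    Hypothetical helper: the vector's L2 norm is the rotation angle and its
    direction is the rotation axis.
    """
    angle = math.sqrt(sum(c * c for c in rotvec))  # L2 norm == rotation angle
    if angle == 0.0:
        # Identity rotation: the axis is undefined, so return a zero axis.
        return [0.0, 0.0, 0.0], 0.0
    return [c / angle for c in rotvec], angle
```

For example, split_rotvec([0.0, 0.0, 2.0]) returns ([0.0, 0.0, 1.0], 2.0), a 2-radian turn about the z axis. Read this way, the norms printed in the original post are the rotation angles returned by matrix_to_axis_angle, and values such as 4.4279 are angles larger than pi.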

Thanks @ptrblck – I’m chiming in because @David_Juergens and I have been working through this offline and we now understand the origin of this mismatch with our expectations.

It is a common convention (e.g., in SciPy’s scipy.spatial.transform.Rotation.as_rotvec, SciPy v1.9.1 manual) for rotation vectors to have norm at most pi. The reason is that if the norm is allowed to be in [0, 2*pi], the representation is not unique: the larger set of possible rotation vectors forms a double cover of the SO(3) group.
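The non-uniqueness can be made explicit with Rodrigues’ rotation formula (standard background, not something stated in the thread), where n is a unit axis, theta the angle, and [n]_× the cross-product matrix of n, so that [-n]_× = -[n]_× and ([-n]_×)^2 = [n]_×^2:

```latex
\begin{aligned}
R(\theta, \mathbf{n}) &= I + \sin\theta\,[\mathbf{n}]_\times + (1 - \cos\theta)\,[\mathbf{n}]_\times^2, \\
R(2\pi - \theta, -\mathbf{n})
  &= I + \sin(2\pi - \theta)\,\bigl(-[\mathbf{n}]_\times\bigr) + \bigl(1 - \cos(2\pi - \theta)\bigr)\,[\mathbf{n}]_\times^2 \\
  &= I + \sin\theta\,[\mathbf{n}]_\times + (1 - \cos\theta)\,[\mathbf{n}]_\times^2
   = R(\theta, \mathbf{n}).
\end{aligned}
```

So theta·n and (theta - 2*pi)·n describe the same rotation. For theta in (pi, 2*pi) the second vector has norm 2*pi - theta < pi, so restricting the norm to [0, pi] leaves one representative per rotation, apart from the pair ±pi·n, which are equivalent to each other.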

For example, imagine v=[1+pi, 0, 0] is a rotation vector. This vector is equivalent to a second rotation vector w=[1-pi, 0, 0], whose norm, pi-1 ≈ 2.14, is less than pi.
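The example can be checked numerically. Below is a minimal pure-Python sketch; rodrigues and wrap_rotvec are hypothetical helper names, not PyTorch3D or SciPy APIs. rodrigues builds the rotation matrix from a rotation vector with Rodrigues’ formula, and wrap_rotvec reduces the angle into [-pi, pi] with math.remainder, which flips the axis whenever the reduced angle is negative.

```python
import math


def rodrigues(rotvec):
    """Rotation matrix (3 lists of 3 floats) for a rotation vector, via Rodrigues' formula."""
    theta = math.sqrt(sum(c * c for c in rotvec))
    if theta == 0.0:
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    x, y, z = (c / theta for c in rotvec)  # unit axis n
    k = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]  # cross-product matrix [n]_x
    k2 = [[sum(k[i][m] * k[m][j] for m in range(3)) for j in range(3)] for i in range(3)]
    sin_t, one_minus_cos = math.sin(theta), 1.0 - math.cos(theta)
    return [[(1.0 if i == j else 0.0) + sin_t * k[i][j] + one_minus_cos * k2[i][j]
             for j in range(3)]
            for i in range(3)]


def wrap_rotvec(rotvec):
    """Equivalent rotation vector whose norm lies in [0, pi]."""
    theta = math.sqrt(sum(c * c for c in rotvec))
    if theta == 0.0:
        return list(rotvec)
    # math.remainder maps theta into [-pi, pi]; a negative result flips the axis.
    wrapped = math.remainder(theta, 2.0 * math.pi)
    return [c * wrapped / theta for c in rotvec]


v = [1.0 + math.pi, 0.0, 0.0]
w = wrap_rotvec(v)
```

Here w comes out as approximately [1 - pi, 0, 0] with norm pi - 1 ≈ 2.1416, and rodrigues(v) and rodrigues(w) agree to floating-point precision. Using math.remainder rather than a single subtraction of 2*pi also handles norms of 3*pi or more.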

This matters in the context of the machine learning application we are working on, because our goal is to predict a rotation vector with norm at most pi. Imposing a loss on v versus w makes a practical difference.
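For question (B), one option is to post-process the output of matrix_to_axis_angle in batch. The sketch below uses a hypothetical helper, wrap_axis_angle, which is not part of PyTorch3D. It assumes the input norms lie in [0, 2*pi), which matches the values observed above, and it uses only differentiable tensor operations so it can sit in front of a loss.

```python
import math

import torch


def wrap_axis_angle(axis_angle: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return equivalent axis-angle vectors whose norms lie in [0, pi].

    Hypothetical helper (not part of PyTorch3D). Expects a tensor of shape (..., 3)
    whose norms lie in [0, 2*pi). A vector theta * n with theta > pi is replaced by
    (theta - 2*pi) * n, i.e. a turn of 2*pi - theta about -n, the same rotation.
    """
    theta = axis_angle.norm(dim=-1, keepdim=True)  # shape (..., 1)
    # torch.where evaluates both branches; the clamp keeps the unused branch
    # finite for zero vectors. Rows with theta <= pi are returned unchanged.
    scale = (theta - 2.0 * math.pi) / theta.clamp(min=eps)
    return torch.where(theta > math.pi, axis_angle * scale, axis_angle)
```

Applied to the thread’s data, this would read ax_ang = wrap_axis_angle(rotation_conversions.matrix_to_axis_angle(test_R)); for example, the observed norm 4.4279 would become 2*pi - 4.4279 ≈ 1.8553, about the opposite axis. Alternatively, assuming your copy of rotation_conversions.py provides standardize_quaternion (which makes the quaternion’s real part non-negative), converting via quaternion_to_axis_angle(standardize_quaternion(matrix_to_quaternion(R))) should also yield angles in [0, pi]; check this against the source of your version.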