I’m noticing that converting a set of rotation matrices to axis-angle form produces axis-angle vectors whose norm is greater than pi, which seems wrong. I’m using a lightly modified version of rotation_conversions.py, and I show a diff against Meta’s original below.
My questions are:
(A) is this behavior expected?
(B) If it is expected, what is the best way to normalize the vectors so that their magnitudes always lie in [0, pi]?
Thanks in advance!
Instructions To Reproduce the Issue:
- Here is a diff between my rotation_conversions.py, which I am using to reproduce this behavior, and the original from Meta (meta_rotation_conversions.py).
I created meta_rotation_conversions.py by copying the code from the pytorch3d.transforms.rotation_conversions page of the PyTorch3D documentation into a file I called tmp.py and then running:
sed s:"\[docs\]":'':g tmp.py > meta_rotation_conversions.py
diff rotation_conversions.py meta_rotation_conversions.py
11a12,13
> from ..common.datatypes import Device
>
51d52
< # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
70a72
>
133,134d134
< # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
< # `int`.
136,137d135
< # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
< # `int`.
139,140d136
< # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
< # `int`.
142,143d137
< # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
< # `int`.
158c152
< F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
---
> F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
161a156
>
220a216
>
305a302
>
307c304
< n: int, dtype: Optional[torch.dtype] = None, device = None
---
> n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
324d320
< # pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
330a327
>
332c329
< n: int, dtype: Optional[torch.dtype] = None, device = None
---
> n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
349a347
>
351c349
< dtype: Optional[torch.dtype] = None, device = None
---
> dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
366a365
>
381a381
>
402a403
>
419a421
>
436a439
>
459a463
>
475a480
>
491a497
>
523a530
>
554a562
>
578a587
>
595c604
< return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
\ No newline at end of file
---
> return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
- Here are the exact commands I ran to produce axis-angle vectors that have norm beyond pi:
import torch
print(torch.__version__)
import rotation_conversions
data = torch.load('ax_ang_outputs.pt', map_location=torch.device('cpu'))
test_R = data['R_pred'][0][0]
print('This is test rotations')
print(test_R)
# get axis angle from matrix
ax_ang_from_R = rotation_conversions.matrix_to_axis_angle(test_R)
# check norm
print(torch.norm(ax_ang_from_R, p=2, dim=-1))
- Here’s what I observed in the output
1.9.0
This is test rotations
tensor([[[ 0.1494, 0.4837, -0.8624],
[ 0.7733, -0.6007, -0.2030],
[-0.6162, -0.6365, -0.4638]],
[[-0.0714, 0.4292, 0.9004],
[ 0.9666, 0.2525, -0.0437],
[-0.2462, 0.8672, -0.4328]],
[[ 0.5917, 0.2904, -0.7520],
[ 0.7764, 0.0455, 0.6285],
[ 0.2167, -0.9558, -0.1986]],
[[ 0.0833, -0.4428, 0.8927],
[-0.8289, -0.5281, -0.1846],
[ 0.5532, -0.7246, -0.4110]],
[[-0.1795, -0.3386, 0.9237],
[-0.7774, 0.6242, 0.0778],
[-0.6029, -0.7041, -0.3752]],
[[ 0.5805, 0.8141, -0.0176],
[ 0.2735, -0.2153, -0.9375],
[-0.7670, 0.5393, -0.3477]],
[[ 0.4837, -0.4856, -0.7282],
[-0.7797, -0.6171, -0.1064],
[-0.3977, 0.6192, -0.6771]],
[[-0.6256, -0.7073, 0.3291],
[-0.5258, 0.6939, 0.4920],
[-0.5763, 0.1348, -0.8060]],
[[-0.2255, 0.6738, 0.7036],
[ 0.4855, 0.7039, -0.5185],
[-0.8446, 0.2247, -0.4859]],
[[ 0.6165, 0.3816, -0.6888],
[-0.3325, -0.6667, -0.6670],
[-0.7137, 0.6402, -0.2841]],
[[-0.1925, -0.8007, -0.5672],
[-0.9634, 0.0443, 0.2645],
[-0.1867, 0.5974, -0.7799]],
[[-0.8547, -0.0775, -0.5133],
[-0.2145, 0.9531, 0.2133],
[ 0.4727, 0.2924, -0.8313]],
[[-0.2660, 0.8619, 0.4316],
[-0.9161, -0.3655, 0.1652],
[ 0.3001, -0.3515, 0.8868]],
[[ 0.6904, 0.5019, -0.5209],
[-0.4371, 0.8633, 0.2525],
[ 0.5764, 0.0534, 0.8154]],
[[-0.4372, -0.8915, -0.1191],
[-0.2856, 0.2632, -0.9215],
[ 0.8528, -0.3688, -0.3696]],
[[ 0.4659, 0.7022, -0.5384],
[ 0.1695, 0.5264, 0.8332],
[ 0.8684, -0.4795, 0.1262]],
[[-0.2217, -0.6468, -0.7297],
[ 0.9226, 0.1031, -0.3716],
[ 0.3156, -0.7557, 0.5739]],
[[-0.9733, -0.1339, 0.1866],
[ 0.0948, -0.9742, -0.2048],
[ 0.2092, -0.1816, 0.9609]],
[[-0.2019, 0.9766, 0.0736],
[-0.3784, -0.1471, 0.9139],
[ 0.9034, 0.1567, 0.3993]],
[[ 0.3122, 0.2442, -0.9181],
[ 0.6225, 0.6774, 0.3919],
[ 0.7177, -0.6939, 0.0595]],
[[-0.6140, -0.6281, -0.4779],
[ 0.7845, -0.4196, -0.4565],
[ 0.0862, -0.6553, 0.7504]],
[[-0.8525, 0.3715, 0.3678],
[-0.2567, -0.9104, 0.3245],
[ 0.4553, 0.1822, 0.8715]],
[[ 0.1279, 0.9249, -0.3580],
[-0.0424, 0.3657, 0.9297],
[ 0.9909, -0.1038, 0.0860]],
[[ 0.1305, -0.2174, -0.9673],
[ 0.9185, 0.3937, 0.0354],
[ 0.3732, -0.8931, 0.2511]],
[[-0.8646, -0.5018, 0.0267],
[ 0.4707, -0.8274, -0.3063],
[ 0.1758, -0.2522, 0.9516]],
[[-0.4554, 0.8043, 0.3817],
[-0.3966, -0.5671, 0.7219],
[ 0.7971, 0.1774, 0.5772]],
[[ 0.3839, 0.5596, -0.7345],
[ 0.3787, 0.6300, 0.6780],
[ 0.8422, -0.5384, 0.0299]],
[[-0.3423, -0.6588, -0.6699],
[ 0.9109, -0.0579, -0.4085],
[ 0.2303, -0.7501, 0.6200]],
[[-0.9885, 0.0321, 0.1475],
[-0.0418, -0.9971, -0.0630],
[ 0.1450, -0.0684, 0.9871]],
[[-0.5158, 0.8466, 0.1310],
[ 0.0590, 0.1877, -0.9805],
[-0.8547, -0.4980, -0.1467]],
[[-0.2505, 0.3370, -0.9076],
[-0.8820, -0.4660, 0.0704],
[-0.3992, 0.8181, 0.4140]],
[[-0.2778, 0.1515, 0.9486],
[-0.1057, 0.9767, -0.1869],
[-0.9548, -0.1522, -0.2553]],
[[ 0.6881, 0.7063, 0.1661],
[ 0.0777, 0.1558, -0.9847],
[-0.7214, 0.6905, 0.0524]],
[[ 0.3823, -0.2206, 0.8973],
[-0.2440, -0.9607, -0.1323],
[ 0.8912, -0.1684, -0.4211]],
[[ 0.5232, 0.6215, 0.5831],
[-0.8507, 0.4215, 0.3141],
[-0.0506, -0.6604, 0.7492]],
[[-0.0445, 0.9967, -0.0677],
[-0.8410, -0.0739, -0.5359],
[-0.5392, 0.0331, 0.8416]],
[[ 0.7902, -0.5468, -0.2768],
[-0.5919, -0.7981, -0.1128],
[-0.1593, 0.2529, -0.9543]],
[[-0.4070, -0.7162, -0.5670],
[-0.7906, 0.5871, -0.1740],
[ 0.4575, 0.3774, -0.8051]],
[[-0.3221, 0.9386, 0.1233],
[ 0.8099, 0.3406, -0.4776],
[-0.4903, -0.0539, -0.8699]],
[[ 0.6649, -0.2304, 0.7105],
[ 0.3991, 0.9137, -0.0771],
[-0.6314, 0.3348, 0.6994]],
[[-0.1578, -0.6089, -0.7774],
[-0.5264, -0.6142, 0.5880],
[-0.8355, 0.5020, -0.2236]],
[[ 0.9653, -0.2548, -0.0574],
[ 0.0103, -0.1827, 0.9831],
[-0.2610, -0.9496, -0.1737]],
[[-0.0468, 0.6846, 0.7274],
[ 0.7944, 0.4669, -0.3884],
[-0.6055, 0.5597, -0.5658]],
[[ 0.8492, -0.5221, -0.0788],
[ 0.0201, 0.1810, -0.9833],
[ 0.5276, 0.8335, 0.1642]],
[[ 0.8291, -0.5556, -0.0633],
[ 0.2867, 0.5196, -0.8048],
[ 0.4800, 0.6491, 0.5901]],
[[-0.6534, -0.1525, -0.7415],
[ 0.6651, 0.3521, -0.6585],
[ 0.3615, -0.9234, -0.1286]],
[[-0.1768, 0.9773, -0.1165],
[-0.4400, 0.0274, 0.8976],
[ 0.8804, 0.2100, 0.4251]],
[[ 0.1008, -0.7674, -0.6332],
[ 0.3219, 0.6273, -0.7091],
[ 0.9414, -0.1323, 0.3103]],
[[-0.8693, 0.1492, 0.4712],
[-0.0189, -0.9627, 0.2700],
[ 0.4939, 0.2258, 0.8397]],
[[-0.3912, -0.3320, -0.8583],
[-0.4360, 0.8882, -0.1448],
[ 0.8104, 0.3176, -0.4922]],
[[-0.9286, -0.1917, -0.3177],
[ 0.3708, -0.4515, -0.8116],
[ 0.0121, -0.8715, 0.4903]],
[[ 0.1223, 0.0621, -0.9906],
[-0.9195, -0.3687, -0.1366],
[-0.3737, 0.9275, 0.0120]],
[[ 0.1735, 0.8650, 0.4709],
[ 0.8893, -0.3431, 0.3025],
[ 0.4232, 0.3662, -0.8287]],
[[ 0.4478, -0.6575, -0.6059],
[ 0.7323, 0.6585, -0.1734],
[ 0.5130, -0.3661, 0.7764]],
[[-0.4742, -0.1710, 0.8636],
[ 0.6836, -0.6896, 0.2389],
[ 0.5547, 0.7037, 0.4440]],
[[-0.1458, 0.5584, -0.8167],
[-0.7840, -0.5686, -0.2489],
[-0.6034, 0.6040, 0.5207]],
[[-0.4864, 0.8186, 0.3055],
[ 0.8697, 0.4870, 0.0800],
[-0.0833, 0.3046, -0.9488]],
[[-0.1426, -0.9897, 0.0134],
[ 0.9831, -0.1432, -0.1144],
[ 0.1152, -0.0031, 0.9933]]], grad_fn=<SelectBackward>)
tensor([3.4340, 2.2470, 4.4279, 3.5236, 2.0547, 2.0843, 2.7027, 2.6237, 2.0987,
2.3012, 3.4104, 3.6644, 4.3308, 0.8168, 3.8307, 1.5115, 1.8466, 3.0259,
4.2176, 1.5462, 2.2674, 3.4726, 1.7825, 1.6834, 2.6265, 3.9048, 1.5489,
1.9715, 3.1787, 2.4000, 4.0031, 1.8528, 1.6227, 3.1633, 1.2165, 4.5735,
2.9470, 3.7639, 2.7536, 0.8776, 3.2080, 4.5156, 2.1808, 1.4734, 1.0822,
3.9158, 4.3418, 1.5516, 3.2293, 4.1915, 2.8080, 2.2360, 3.1000, 1.1137,
2.6060, 4.0728, 2.9137, 1.7175], grad_fn=<NormBackward1>)
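Regarding question (B), here is a minimal sketch of one possible normalization. It relies on the fact that a rotation by an angle theta in (pi, 2*pi) about an axis is the same rotation as one by 2*pi - theta about the opposite axis. The helper name wrap_axis_angle is hypothetical and not part of PyTorch3D, and the sketch handles a single 3-vector in plain Python; a batched torch version would presumably apply the same arithmetic with torch.norm and torch.where.

```python
import math


def wrap_axis_angle(v):
    """Return the axis-angle vector equivalent to `v` whose norm lies in [0, pi].

    `v` is a length-3 sequence of floats (axis direction scaled by the angle).
    Hypothetical helper for illustration; not a PyTorch3D function.
    """
    x, y, z = v
    angle = math.sqrt(x * x + y * y + z * z)
    if angle == 0.0:
        # Identity rotation: nothing to wrap.
        return (0.0, 0.0, 0.0)

    # Rotations are periodic in the angle, so first reduce to [0, 2*pi).
    wrapped = math.fmod(angle, 2.0 * math.pi)

    if wrapped > math.pi:
        # Same rotation expressed about the opposite axis with angle 2*pi - wrapped.
        scale = -(2.0 * math.pi - wrapped) / angle
    else:
        scale = wrapped / angle

    return (x * scale, y * scale, z * scale)


if __name__ == "__main__":
    # A vector with norm 4.0, which is above pi, as in the output above.
    fixed = wrap_axis_angle((0.0, 0.0, 4.0))
    print(fixed, math.sqrt(sum(c * c for c in fixed)))
```

Vectors whose norm is already at most pi pass through unchanged, so it should be safe to apply this to the whole output of matrix_to_axis_angle.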