Hi all,
I have this function:
```python
import numpy as np
import torch


def rotation_matrices(angles, axis):
    """Compute rotation matrices from angle/axis representations.

    Parameters
    ----------
    angles : (n,) float
        The angles.
    axis : (3,) float
        The axis.

    Returns
    -------
    rots : (n,4,4)
        The rotation matrices
    """
    axis = axis / torch.norm(axis)
    sina = torch.sin(angles)
    cosa = torch.cos(angles)
    M = torch.eye(4, device=angles.device).repeat((len(angles), 1, 1))
    M[:, 0, 0] = cosa
    M[:, 1, 1] = cosa
    M[:, 2, 2] = cosa
    M[:, :3, :3] += (
        torch.ger(axis, axis).repeat((len(angles), 1, 1))
        * (1.0 - cosa)[:, np.newaxis, np.newaxis]
    )
    M[:, :3, :3] += (
        torch.tensor(
            [
                [0.0, -axis[2], axis[1]],
                [axis[2], 0.0, -axis[0]],
                [-axis[1], axis[0], 0.0],
            ],
            device=angles.device,
        ).repeat((len(angles), 1, 1))
        * sina[:, np.newaxis, np.newaxis]
    )
    return M
```
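For reference, the upper-left 3×3 block that the function assembles is Rodrigues' rotation formula, with unit axis $\mathbf{k} = (k_x, k_y, k_z)$ = `axis / torch.norm(axis)` and angle $\theta$ taken from `angles`; the fourth row and column stay as in the identity:

$$
R(\theta) = \cos\theta \, I_3 + (1 - \cos\theta)\, \mathbf{k}\mathbf{k}^\top + \sin\theta \, [\mathbf{k}]_\times,
\qquad
[\mathbf{k}]_\times = \begin{pmatrix} 0 & -k_z & k_y \\ k_z & 0 & -k_x \\ -k_y & k_x & 0 \end{pmatrix},
\qquad
M = \begin{pmatrix} R(\theta) & \mathbf{0} \\ \mathbf{0}^\top & 1 \end{pmatrix}
$$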
I'd like to be able to call this within `vmap`, like so:
```python
angles = torch.ones(10)
axis = torch.tensor([0., 1., 0.])
rotation_matrices(angles, axis)  # Works fine
batched = torch.vmap(rotation_matrices)
batched(angles[None, ...], axis[None, ...])  # Fails
```
I get the following error:
```
18 in _rotation_matrices(angles, axis)
     16     cosa = torch.cos(angles)
     17     M = torch.eye(4, device=angles.device).repeat((len(angles), 1, 1))
---> 18     M[:, 0, 0] = cosa
     19     M[:, 1, 1] = cosa
     20     M[:, 2, 2] = cosa

RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
How should I change this function so that it works correctly under `vmap`?
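Following the error message's own suggestion to use out-of-place operators, one possible direction is to build the same matrices without any indexed in-place writes (`M[...] = ...`, `+=`). The function below is a minimal sketch, not a definitive fix. It assumes PyTorch 2.x, where `torch.vmap` is the same as `torch.func.vmap`, and the name `rotation_matrices_functional` is hypothetical. It also builds the skew-symmetric matrix with `torch.stack` instead of `torch.tensor`. Once `axis` is vmapped, `torch.tensor([... axis[2] ...])` would likely fail as well, because it needs concrete Python values.

```python
import torch
import torch.nn.functional as F


def rotation_matrices_functional(angles, axis):
    """Out-of-place variant of rotation_matrices; returns (n, 4, 4) matrices.

    No tensor is modified in place, so the function can be wrapped in vmap.
    """
    axis = axis / torch.linalg.norm(axis)
    # Shape (n, 1, 1) so the per-angle scalars broadcast over 3x3 blocks.
    sina = torch.sin(angles)[:, None, None]
    cosa = torch.cos(angles)[:, None, None]

    eye3 = torch.eye(3, dtype=angles.dtype, device=angles.device)
    outer = torch.outer(axis, axis)  # k k^T, shape (3, 3)

    # Cross-product (skew-symmetric) matrix [k]_x, assembled with torch.stack
    # so that it stays batched when axis is batched.
    zero = torch.zeros_like(axis[0])
    skew = torch.stack([
        torch.stack([zero, -axis[2], axis[1]]),
        torch.stack([axis[2], zero, -axis[0]]),
        torch.stack([-axis[1], axis[0], zero]),
    ])

    # Rodrigues' formula for the 3x3 rotation block, shape (n, 3, 3).
    rot3 = cosa * eye3 + (1.0 - cosa) * outer + sina * skew

    # Zero-pad the last two dims to (n, 4, 4), then add 1 at the corner [3, 3].
    corner = torch.diag(
        torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=angles.dtype, device=angles.device)
    )
    return F.pad(rot3, (0, 1, 0, 1)) + corner


angles = torch.ones(10)
axis = torch.tensor([0.0, 1.0, 0.0])
print(rotation_matrices_functional(angles, axis).shape)  # torch.Size([10, 4, 4])

batched = torch.vmap(rotation_matrices_functional)
print(batched(angles[None, ...], axis[None, ...]).shape)  # torch.Size([1, 10, 4, 4])
```

The design choice here is to compute every term as a fresh tensor and combine them with broadcasting. Nothing is created outside the batch and then overwritten with batched values, which is exactly the pattern the `RuntimeError` above complains about.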
Thanks!