How to write a function that works correctly with torch.vmap

Hi all,

I have this function, which builds a batch of 4x4 rotation matrices from a set of angles about a single axis:

import numpy as np
import torch


def rotation_matrices(angles, axis):
    """Compute rotation matrices from angle/axis representations.

    Parameters
    ----------
    angles : (n,) float
        The rotation angles, in radians.
    axis : (3,) float
        The rotation axis (normalized inside the function).

    Returns
    -------
    rots : (n, 4, 4) float
        The homogeneous rotation matrices.
    """
    axis = axis / torch.norm(axis)
    sina = torch.sin(angles)
    cosa = torch.cos(angles)
    M = torch.eye(4, device=angles.device).repeat((len(angles), 1, 1))
    M[:, 0, 0] = cosa
    M[:, 1, 1] = cosa
    M[:, 2, 2] = cosa
    M[:, :3, :3] += (
        torch.outer(axis, axis).repeat((len(angles), 1, 1))  # torch.ger is a deprecated alias of torch.outer
        * (1.0 - cosa)[:, np.newaxis, np.newaxis]
    )
    M[:, :3, :3] += (
        torch.tensor(
            [
                [0.0, -axis[2], axis[1]],
                [axis[2], 0.0, -axis[0]],
                [-axis[1], axis[0], 0.0],
            ],
            device=angles.device,
        ).repeat((len(angles), 1, 1))
        * sina[:, np.newaxis, np.newaxis]
    )
    return M
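
For reference, the function appears to implement Rodrigues' rotation formula for the unit axis k = axis / ||axis||, placed in the top-left 3x3 block of a 4x4 homogeneous matrix. Every term is an ordinary (out-of-place) expression in the angle and the axis, which is what a vmap-friendly rewrite can build on:

```latex
M(\theta) =
\begin{pmatrix}
R(\theta) & \mathbf{0} \\
\mathbf{0}^{\top} & 1
\end{pmatrix},
\qquad
R(\theta) = \cos\theta \, I_3 + (1 - \cos\theta) \, \mathbf{k}\mathbf{k}^{\top} + \sin\theta \, [\mathbf{k}]_{\times},
\qquad
[\mathbf{k}]_{\times} =
\begin{pmatrix}
0 & -k_z & k_y \\
k_z & 0 & -k_x \\
-k_y & k_x & 0
\end{pmatrix}
```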

I'd like to be able to call it through torch.vmap, like so:

angles = torch.ones(10)
axis = torch.tensor([0., 1., 0.])
rotation_matrices(angles, axis)  # Works fine

batched = torch.vmap(rotation_matrices)
batched(angles[None, ...], axis[None, ...])  # Fails
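
To make the vmap call concrete: with the default in_dims=0, torch.vmap strips the leading dimension from every argument, so the function body sees angles of shape (10,) and axis of shape (3,), the same shapes as the working unbatched call. The sketch below uses a hypothetical toy function, scaled_axes, in place of rotation_matrices, and also shows in_dims=(0, None) for sharing one axis across the whole batch; it assumes PyTorch 2.x, where torch.vmap is available at the top level.

```python
import torch


def scaled_axes(angles, axis):
    # Hypothetical toy function with the same signature as rotation_matrices:
    # angles: (n,), axis: (3,) -> (n, 3). Purely out-of-place, so vmap-safe.
    return torch.cos(angles)[:, None] * axis


angles = torch.ones(10)
axis = torch.tensor([0.0, 1.0, 0.0])

# Default in_dims=0: the leading dimension is stripped from BOTH arguments,
# so inside the function angles is (10,) and axis is (3,).
out = torch.vmap(scaled_axes)(angles[None, ...], axis[None, ...])
print(out.shape)  # torch.Size([1, 10, 3])

# in_dims=(0, None): map over angles only and reuse the same axis for every batch entry.
angle_batches = torch.stack([angles, 2 * angles])  # (2, 10)
shared = torch.vmap(scaled_axes, in_dims=(0, None))(angle_batches, axis)
print(shared.shape)  # torch.Size([2, 10, 3])
```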

I get the following error:

18 in _rotation_matrices(angles, axis)
     16 cosa = torch.cos(angles)
     17 M = torch.eye(4, device=angles.device).repeat((len(angles), 1, 1))
---> 18 M[:, 0, 0] = cosa
     19 M[:, 1, 1] = cosa
     20 M[:, 2, 2] = cosa

RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
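
The error message describes the failing line fairly directly: M comes from torch.eye(...).repeat(...), so it is not batched, while cosa is computed from the batched angles, and M[:, 0, 0] = cosa writes the batched values into it in place. Following the message's suggestion to use out-of-place operators, one possible rewrite builds each matrix from broadcasting, torch.stack and matrix products instead of writing into a pre-allocated M. This is a minimal sketch, assuming PyTorch 2.x and tested only on the example inputs from the post; the name rotation_matrices_functional and the embedding matrix P are hypothetical choices for illustration. It also avoids torch.tensor([...]) built from entries of axis, which is likely to fail as well once axis is batched, because it needs concrete Python values.

```python
import torch


def rotation_matrices_functional(angles, axis):
    """Hypothetical out-of-place variant of rotation_matrices.

    angles: (n,) float radians, axis: (3,) float -> (n, 4, 4) float.
    No tensor is modified in place, so vmap can batch every step.
    """
    axis = axis / torch.linalg.vector_norm(axis)
    kx, ky, kz = axis.unbind(0)
    zero = torch.zeros_like(kx)
    # Skew-symmetric cross-product matrix [k]_x, assembled with torch.stack
    # rather than torch.tensor([...]) so it stays a (possibly batched) tensor.
    K = torch.stack([
        torch.stack([zero, -kz, ky]),
        torch.stack([kz, zero, -kx]),
        torch.stack([-ky, kx, zero]),
    ])  # (3, 3)
    kkT = axis[:, None] * axis[None, :]  # outer product k k^T, (3, 3)
    cosa = torch.cos(angles)[:, None, None]  # (n, 1, 1)
    sina = torch.sin(angles)[:, None, None]  # (n, 1, 1)
    eye3 = torch.eye(3, dtype=angles.dtype, device=angles.device)
    # Rodrigues' formula, broadcast over the n angles -> (n, 3, 3).
    R = cosa * eye3 + (1.0 - cosa) * kkT + sina * K
    # Embed R in the top-left 3x3 block: P = [I_3 | 0] is (3, 4), so P.T @ R @ P
    # is (n, 4, 4), and eye(4) - P.T @ P = diag(0, 0, 0, 1) supplies the corner 1.
    P = torch.eye(3, 4, dtype=angles.dtype, device=angles.device)
    eye4 = torch.eye(4, dtype=angles.dtype, device=angles.device)
    return P.T @ R @ P + (eye4 - P.T @ P)


angles = torch.ones(10)
axis = torch.tensor([0.0, 1.0, 0.0])
print(rotation_matrices_functional(angles, axis).shape)  # torch.Size([10, 4, 4])

batched = torch.vmap(rotation_matrices_functional)
print(batched(angles[None, ...], axis[None, ...]).shape)  # torch.Size([1, 10, 4, 4])
```

The embedding through P is just one way to assemble the 4x4 result without indexing assignments; torch.cat with a zero column and row would be another option.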

How should I change this function so that it works correctly under vmap?

Thanks!