AssertionError: Mismatch in batch_size and euler_angles shape

I tried to compute the Euler angles from the quaternion using this code:

def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
    """
    Convert a batch of quaternions to Euler angles.

    Args:
        quaternions: Tensor of shape (batch_size, 4). Batch of quaternions.
        sequence: String specifying the rotation sequence. Default is 'xyz'.

    Returns:
        euler_angles: Tensor of shape (batch_size, 3). Batch of Euler angles.
    """
    batch_size = quaternions.shape[0]
    q = quaternions.detach().cpu().numpy()  # Convert to NumPy array

    rotations = Rotation.from_quat(q)
    euler_angles = rotations.as_euler(sequence, degrees=False)

    euler_angles = torch.tensor(euler_angles, device=quaternions.device)
    euler_angles = euler_angles.view(batch_size, 3)

    return euler_angles

But I got this issue when batch sizes = 1 or 2, or 16:

RuntimeError: shape '[4, 3]' is invalid for input of size 3

I tried to fix it using this code as I found some solution here:

def compute_euler_angles_from_quaternion(quaternions, sequence='xyz'):
    batch_size = quaternions.shape[0]  # Ensure that batch_size is correctly calculated
    print("Batch size:", batch_size)

    q = quaternions.detach().cpu().numpy()  # Convert to NumPy array

    rotations = Rotation.from_quat(q)
    euler_angles = rotations.as_euler(sequence, degrees=False)
    print("Shape of euler_angles:", euler_angles.shape)

    if batch_size == 1:
        euler_angles = euler_angles.reshape(1, -1)  # Reshape single quaternion to (1, 3)

    euler_angles = torch.tensor(euler_angles, device=quaternions.device)

    assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
    euler_angles = euler_angles.view(batch_size, 3)  # Reshape to [batch_size, 3]

    return euler_angles

I got this issue:

Batch size: 4
Shape of euler_angles: (3,)
Traceback (most recent call last):
  File "test_quat.py", line 147, in <module>
    euler = utils.compute_euler_angles_from_quaternion(
  File "/home/redhwan/2/HPE/quat/utils.py", line 318, in compute_euler_angles_from_quaternion
    assert batch_size == euler_angles.shape[0], "Mismatch in batch_size and euler_angles shape"
AssertionError: Mismatch in batch_size and euler_angles shape

I would like the model to work at any batch size.
Please help.