How to roll some rows of a matrix

I have a tensor mxm and I want to roll each row by its own computed amount (for an easy example: row_index // 3). I have tried using torch.roll, but it appears to apply a single shift to everything along a dimension rather than a different shift per row.
Any ideas?

Do you need to roll every row (as per your post) or only some rows (as per your title)?
If every row:

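import torch
dim = 0  # unbind along dim 0, i.e. one slice per row
# list_of_computed_amounts: one integer shift per row
output = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
output = torch.stack(output, dim)  # put the rolled rows back into one tensor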

Here torch.unbind returns a tuple of all slices along the given dimension, and torch.roll is applied to each of these rows in turn.
list_of_computed_amounts should hold your shifts, one per row, so its length equals the number of rows.
torch.stack then puts the rolled rows back into one tensor.

Otherwise if only some rows, loop through the indices and computed amounts with torch.roll and set that row to the new rolled values.
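
A rough sketch of that loop, in case it helps. The tensor contents, row_indices and shifts below are made-up example values, not something from your post:

```python
import torch

# Hypothetical example data: a 4x3 tensor, rolling only rows 1 and 3
mxm = torch.arange(12).reshape(4, 3)
row_indices = [1, 3]  # assumed: the rows you want to roll
shifts = [1, 2]       # assumed: the computed amount for each of those rows

output = mxm.clone()  # copy so the original tensor is left untouched
for i, s in zip(row_indices, shifts):
    # roll the 1-D row by s positions and write it back into the copy
    output[i] = torch.roll(mxm[i], s)
```

Rows that are not listed in row_indices stay exactly as they were.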

Obviously not super efficient, but something that might work.
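
If speed matters, one possible alternative (not tested on your data, just a sketch) is to build the column indices yourself and use torch.gather, which avoids the Python-level loop over rows. The sizes and the row_index // 3 shifts below are example assumptions taken from your "easy example":

```python
import torch

# Hypothetical example data: 6 rows, 4 columns
n_rows, n_cols = 6, 4
mxm = torch.arange(n_rows * n_cols).reshape(n_rows, n_cols)

# One shift per row, here row_index // 3 as in the question's example
shifts = torch.arange(n_rows) // 3

# torch.roll moves element j to position j + shift, so output column j
# must read from input column (j - shift) % n_cols
cols = torch.arange(n_cols)
idx = (cols.unsqueeze(0) - shifts.unsqueeze(1)) % n_cols  # shape (n_rows, n_cols)

# For dim=1, gather picks rolled[i, j] = mxm[i, idx[i, j]]
rolled = torch.gather(mxm, 1, idx)
```

The index tensor has the same shape as mxm, so this trades a bit of memory for doing all rows in one call.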


Sorry for being vague. I need to roll every row.