How to roll some rows of a matrix

Do you need to roll every row (as per your post) or only some rows (as per your title)?
If every row:

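import torch
dim = 0  # unbind along dim 0, i.e. iterate over all the rows
# roll each row by its own shift (one integer per row)
output = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
output = torch.stack(output, dim)  # reassemble the rolled rows into one tensor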

Here torch.unbind returns a tuple of all slices along the given dimension, and torch.roll is applied to each of these rows.
list_of_computed_amounts holds your shifts, one per row, so its length must equal the number of rows.
torch.stack then puts the rolled rows back into one tensor.
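To make this concrete, here is a self-contained run of the same unbind/roll/stack approach. The 3x4 matrix and the shift values are made-up example data; a positive shift moves elements to the right, as with torch.roll.

```python
import torch

# Hypothetical example data: a 3x4 matrix and one shift per row.
mxm = torch.arange(12).reshape(3, 4)
list_of_computed_amounts = [1, 2, -1]

dim = 0  # iterate over the rows
# torch.roll(row, shift) is called once per (row, shift) pair
output = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
output = torch.stack(output, dim)
print(output.tolist())  # [[3, 0, 1, 2], [6, 7, 4, 5], [9, 10, 11, 8]]
```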

Otherwise, if you only need some rows, loop over the row indices and their computed amounts, apply torch.roll to each of those rows, and assign the rolled values back to that row.
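A minimal sketch of that loop, assuming you have a list of row indices and a matching list of shifts (the names row_indices and computed_amounts are hypothetical). It works on a clone so the original matrix is left untouched; drop the clone if modifying mxm in place is fine for you.

```python
import torch

# Hypothetical example data: roll rows 0 and 2, leave row 1 alone.
mxm = torch.arange(12).reshape(3, 4)
row_indices = [0, 2]
computed_amounts = [1, -1]

output = mxm.clone()  # keep the original matrix unchanged
for idx, shift in zip(row_indices, computed_amounts):
    # torch.roll returns a new tensor, which is then written back into that row
    output[idx] = torch.roll(output[idx], shift)
print(output.tolist())  # [[3, 0, 1, 2], [4, 5, 6, 7], [9, 10, 11, 8]]
```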

Obviously not super efficient, but something that might work.
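If the Python loop turns out to be a bottleneck, one possible vectorized alternative (a different technique from the one above) is to build a per-row column index and use torch.gather. This is a sketch assuming a 2-D tensor and an integer tensor with one shift per row; the helper name roll_rows is hypothetical. To roll only some rows, give the other rows a shift of 0.

```python
import torch

def roll_rows(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Roll row i of the 2-D tensor x by shifts[i], using torch.roll's sign convention."""
    n_cols = x.size(1)
    cols = torch.arange(n_cols, device=x.device)  # shape (n_cols,)
    # Output column j of row i takes the element at column (j - shifts[i]) mod n_cols.
    # Tensor % follows Python semantics, so negative shifts wrap around correctly.
    src = (cols.unsqueeze(0) - shifts.to(x.device).unsqueeze(1)) % n_cols
    return torch.gather(x, 1, src)  # src has shape (n_rows, n_cols), dtype int64

mxm = torch.arange(12).reshape(3, 4)
all_rows = roll_rows(mxm, torch.tensor([1, 2, -1]))
some_rows = roll_rows(mxm, torch.tensor([1, 0, -1]))  # row 1 is left unchanged
print(all_rows.tolist())   # [[3, 0, 1, 2], [6, 7, 4, 5], [9, 10, 11, 8]]
print(some_rows.tolist())  # [[3, 0, 1, 2], [4, 5, 6, 7], [9, 10, 11, 8]]
```

This replaces the per-row Python calls with a few tensor operations, at the cost of materializing an index tensor the same size as the input.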
