Do you need to roll every row (as in your post) or only some rows (as in your title)?
If every row:
import torch
dim = 0  # unbind along dim 0, so each slice is one row
output = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
output = torch.stack(output, dim)
Here torch.unbind returns a tuple of all slices along the given dimension (with dim = 0, one 1-D tensor per row). torch.roll is then applied to each of these rows, with list_of_computed_amounts supplying the shifts: one integer per row, so its length must equal the number of rows. torch.stack then puts the rolled rows back into a single tensor.
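A self-contained version of the every-row approach, as a minimal sketch: the 3x4 tensor below stands in for mxm, and the shift values in list_of_computed_amounts are made up for illustration.

```python
import torch

# Hypothetical 3x4 tensor standing in for `mxm` from the question.
mxm = torch.arange(12).reshape(3, 4)
# Hypothetical shifts: one integer per row, so len == mxm.size(0).
list_of_computed_amounts = [1, 2, -1]

dim = 0  # unbind along dim 0, so each slice is one 1-D row
# Each row is 1-D, so torch.roll without `dims` rolls along its only axis.
rolled_rows = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
# Reassemble the rolled rows into a single 3x4 tensor.
output = torch.stack(rolled_rows, dim)

print(output.tolist())  # [[3, 0, 1, 2], [6, 7, 4, 5], [9, 10, 11, 8]]
```

Each row moves right by its shift (a negative shift moves it left), with elements wrapping around at the end of the row.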
Otherwise, if you only need to roll some rows, loop over the row indices and their computed amounts, apply torch.roll to each of those rows, and write the result back into that row.
This is obviously not super efficient, but it should work.
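The some-rows loop could look like the sketch below. The row indices and shift values are hypothetical, and the tensor is cloned first so the original stays unchanged (drop the clone if modifying mxm in place is fine for you).

```python
import torch

# Hypothetical 3x4 tensor standing in for `mxm` from the question.
mxm = torch.arange(12).reshape(3, 4)
# Hypothetical choice: roll only rows 0 and 2, by 1 and -1 respectively.
row_indices = [0, 2]
computed_amounts = [1, -1]

output = mxm.clone()  # work on a copy so `mxm` is left untouched
for idx, shift in zip(row_indices, computed_amounts):
    # torch.roll returns a new tensor; assign it back over the old row.
    output[idx] = torch.roll(output[idx], shift)

print(output.tolist())  # [[3, 0, 1, 2], [4, 5, 6, 7], [9, 10, 11, 8]]
```

Rows not listed in row_indices (row 1 here) keep their original values.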