I have a tensor `mxm` and I want to roll each row by a specific, computed amount (for an easy example: `row_index // 3`). I have tried using `torch.roll`, but it appears to roll a whole dimension by a single amount rather than each row by its own amount.
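For reference, here is a minimal illustration of the behavior described, using a hypothetical 4-by-4 `mxm`. A single `shifts` value with `dims=1` moves every row by the same amount, so there is no way to give each row its own shift in one call.

```python
import torch

m = 4
mxm = torch.arange(m * m).reshape(m, m)  # hypothetical example tensor

# One shift along dims=1 moves every row right by the same amount (1 here).
same_shift = torch.roll(mxm, shifts=1, dims=1)
print(same_shift)
```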

Any ideas?

Do you need to roll every row (as per your post) or only some rows (as per your title)?

If every row:

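```
import torch

dim = 0  # unbind along dim 0, i.e. iterate over the rows
# roll each row by its own shift; list_of_computed_amounts holds one int per row
output = list(map(torch.roll, torch.unbind(mxm, dim), list_of_computed_amounts))
output = torch.stack(output, dim)  # put the rolled rows back into one tensor
```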

Here `torch.unbind` returns a tuple of all slices along the given dimension. You apply `torch.roll` to each of these rows; `list_of_computed_amounts` holds your shifts, one per row, so its length should equal the number of rows. `torch.stack` then puts the rolled rows back into one tensor.
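A self-contained sketch of the same unbind, roll, stack approach. It assumes a hypothetical 6-by-6 `mxm` and builds `list_of_computed_amounts` from the `row_index // 3` example in the question:

```python
import torch

m = 6
mxm = torch.arange(m * m).reshape(m, m)  # hypothetical m-by-m example tensor

# One shift per row, using the example rule from the question: row_index // 3
list_of_computed_amounts = [row_index // 3 for row_index in range(m)]

dim = 0  # unbind along dim 0, i.e. split the tensor into its rows
rows = torch.unbind(mxm, dim)  # tuple of 1-D row tensors
rolled = [torch.roll(row, shifts=s, dims=0) for row, s in zip(rows, list_of_computed_amounts)]
output = torch.stack(rolled, dim)  # same shape as mxm
```

With these shifts, rows 0 to 2 are unchanged and rows 3 to 5 are each rotated right by one position.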

Otherwise, if you only need some rows, loop through the row indices and computed amounts, call `torch.roll` on each selected row, and assign the rolled values back to that row.
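A minimal sketch of that loop. The `row_indices` and `computed_amounts` lists are hypothetical (one shift per selected row), and the tensor is cloned so the original `mxm` is left untouched:

```python
import torch

m = 5
mxm = torch.arange(m * m).reshape(m, m)  # hypothetical example tensor

# Hypothetical selection: roll row 1 by 2 and row 3 by -1; other rows stay put
row_indices = [1, 3]
computed_amounts = [2, -1]

output = mxm.clone()  # copy so the original tensor is not modified
for idx, shift in zip(row_indices, computed_amounts):
    # roll the 1-D row and write the result back into the same row
    output[idx] = torch.roll(mxm[idx], shifts=shift, dims=0)
```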

Obviously not super efficient, but something that might work.
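If the Python-level loop becomes a bottleneck, one alternative (a different technique from the reply above, not something it proposes) is to fold each row's shift into a column index and do a single `torch.gather`. This is a sketch assuming a 2-D tensor and one integer shift per row; `roll_rows` is a hypothetical helper name.

```python
import torch


def roll_rows(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Roll each row of a 2-D tensor x by its own shift, like torch.roll per row.

    shifts: 1-D integer tensor with one entry per row (positive = shift right).
    """
    if shifts.numel() != x.shape[0]:
        raise ValueError("need exactly one shift per row")
    shifts = shifts.to(device=x.device, dtype=torch.long)
    n_cols = x.shape[1]
    cols = torch.arange(n_cols, device=x.device)
    # After rolling row r by s, column j holds the old element at (j - s) mod n_cols.
    # Tensor % uses the sign of the divisor, so negative shifts also map into range.
    idx = (cols.unsqueeze(0) - shifts.unsqueeze(1)) % n_cols
    return torch.gather(x, 1, idx)


m = 6
mxm = torch.arange(m * m).reshape(m, m)  # hypothetical example tensor
shifts = torch.arange(m) // 3  # the row_index // 3 rule from the question
out = roll_rows(mxm, shifts)
```

The index tensor has the same shape as `x` (in int64), so this trades some extra memory for replacing one `torch.roll` call per row with one `gather`.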


Sorry for being vague. I need to roll every row.