Implementation of function like `numpy.roll`

I am trying to implement a function similar to numpy.roll. To keep everything on the GPU (CUDA), I want to use only torch Tensors, but it seems hard.

Ultimately, I want to implement an image gradient with forward differences and a Neumann boundary condition. For example, the numpy version is as follows:

import numpy as np

def grad(u):
    # u: 2-d image
    ux = np.roll(u, -1, axis=1) - u  # forward difference along x
    uy = np.roll(u, -1, axis=0) - u  # forward difference along y
    ux[:, -1] = 0  # Neumann boundary: zero gradient at the last column
    uy[-1, :] = 0  # Neumann boundary: zero gradient at the last row
    return ux, uy

I tried to use a [-1, 1] filter with torch.nn.functional.conv2d, but because conv2d has no boundary option except zero padding, and the filter has even length, it gets complicated.
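Roughly, the conv2d version I have in mind looks like this (a sketch, assuming a float image in the (N, C, H, W) layout that conv2d expects):

import torch
import torch.nn.functional as F

u = torch.rand(1, 1, 4, 5)                # (N, C, H, W) image
k = torch.tensor([[[[-1.0, 1.0]]]])       # [-1, 1] filter, shape (1, 1, 1, 2)
ux = F.conv2d(F.pad(u, (0, 1, 0, 0)), k)  # pad one column on the right only
ux[..., -1] = 0                           # then zero the last column (Neumann)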

Can someone help me implement something like np.roll?


Thanks for your reply, although I had already checked that. As you said, it seems hard to implement.

It should be quite simple to implement yourself. Just slice the tensor into two pieces, swap them, and cat along the same dimension that you used to split.
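For example, a minimal sketch for a 1-d tensor and a positive shift of 2:

import torch

x = torch.arange(5)
front, back = x[:-2], x[-2:]          # slice into two pieces
rolled = torch.cat((back, front), 0)  # swap and cat: tensor([3, 4, 0, 1, 2])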


I made an account to say that I think that's a bit yucky; it's nicer to read roll :frowning:


I haven’t tested this extensively, but this seems to cover the case where you just want a single split. The logic for a negative shift could probably be cleaned up a little:

import torch

def roll(tensor, shift, axis):
    if shift == 0:
        return tensor

    if axis < 0:
        axis += tensor.dim()

    dim_size = tensor.size(axis)
    after_start = dim_size - shift
    if shift < 0:
        # a negative shift is the same as a positive shift of dim_size - |shift|
        after_start = -shift
        shift = dim_size - abs(shift)

    # split into the piece that stays in front and the piece that wraps around
    before = tensor.narrow(axis, 0, dim_size - shift)
    after = tensor.narrow(axis, after_start, shift)
    return torch.cat([after, before], axis)
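A quick check against np.roll on a 2-d tensor:

x = torch.arange(6).view(2, 3)
roll(x, 1, 1)   # same as np.roll(x, 1, axis=1): tensor([[2, 0, 1], [5, 3, 4]])
roll(x, -1, 0)  # same as np.roll(x, -1, axis=0): tensor([[3, 4, 5], [0, 1, 2]])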

A simple solution to roll along the first axis:

def roll(x, n):  
    return torch.cat((x[-n:], x[:-n]))

Test it like this:

x = torch.arange(5)
print("Orig:", x)
print("Roll 2:", roll(x, 2))
print("Roll -2:", roll(x, -2))

Outputs:

Orig: tensor([0, 1, 2, 3, 4])
Roll 2: tensor([3, 4, 0, 1, 2])
Roll -2: tensor([2, 3, 4, 0, 1])

To roll along the second axis, use:

def roll_1(x, n):  
    return torch.cat((x[:, -n:], x[:, :-n]), dim=1)

It probably can be generalised, but I didn’t need it.


@jaromiru thank you!
The solution below is a generalization of yours to an arbitrary axis:

def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[int] = None):

    if 0 == shift:
        return x

    elif shift < 0:
        shift = -shift
        # the leading entries wrap around to the end (or get replaced by fill_pad)
        gap = x.index_select(dim, torch.arange(shift))
        if fill_pad is not None:
            gap = fill_pad * torch.ones_like(gap, device=x.device)
        return torch.cat([x.index_select(dim, torch.arange(shift, x.size(dim))), gap], dim=dim)

    else:
        shift = x.size(dim) - shift
        # the trailing entries wrap around to the front (or get replaced by fill_pad)
        gap = x.index_select(dim, torch.arange(shift, x.size(dim)))
        if fill_pad is not None:
            gap = fill_pad * torch.ones_like(gap, device=x.device)
        return torch.cat([gap, x.index_select(dim, torch.arange(shift))], dim=dim)

I tried to use yours, but I get a NameError saying that Optional is not defined.

@Zuanazzi,

from typing import Optional

Extending the solution to respect the input tensor’s device (so it works with CUDA tensors) :slight_smile:

from typing import Optional

def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[int] = None):

    device = x.device
    
    if 0 == shift:
        return x

    elif shift < 0:
        shift = -shift
        gap = x.index_select(dim, torch.arange(shift, device=device))
        if fill_pad is not None:
            gap = fill_pad * torch.ones_like(gap, device=device)
        return torch.cat([x.index_select(dim, torch.arange(shift, x.size(dim), device=device)), gap], dim=dim)

    else:
        shift = x.size(dim) - shift
        gap = x.index_select(dim, torch.arange(shift, x.size(dim), device=device))
        if fill_pad is not None:
            gap = fill_pad * torch.ones_like(gap, device=device)
        return torch.cat([gap, x.index_select(dim, torch.arange(shift, device=device))], dim=dim) 
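A quick sanity check of the wrap and fill_pad behaviour:

x = torch.arange(1, 6)     # tensor([1, 2, 3, 4, 5])
roll(x, 2)                 # tensor([4, 5, 1, 2, 3])
roll(x, -1)                # tensor([2, 3, 4, 5, 1])
roll(x, -1, fill_pad=0)    # shifted-out values are padded: tensor([2, 3, 4, 5, 0])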

Since this is still getting answers in 2020 and is the top Google result, it’s worth pointing out that there is now a proper torch.roll function.
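It reproduces the examples above directly:

import torch

x = torch.arange(5)
torch.roll(x, 2)                     # tensor([3, 4, 0, 1, 2])
torch.roll(x, -2)                    # tensor([2, 3, 4, 0, 1])
torch.roll(x.view(5, 1), 2, dims=0)  # pick the axis with dims=, like np.roll's axis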
