Slice multiple regions simultaneously

Hey,
I’m trying to slice multiple contiguous regions from a larger tensor, preserving their shape.
Currently I manage to do this with the following function:

import torch
import torch.nn.functional as F

# volume is (C, D, H, W)
# centers is (N, 3)
def gather_receiptive_fields(volume, centers, ks=3):
    L = ks // 2
    R = L + 1  # slice end indices are exclusive, thus +1
    pad_vol = F.pad(volume, tuple([L]*6))  # pad in case a center is at the volume border
    return torch.stack([pad_vol[...,       # keep channel dim as is
            coord[0]-L : coord[0]+R,
            coord[1]-L : coord[1]+R,
            coord[2]-L : coord[2]+R] for coord in centers + L  # shift by ks//2 to account for padding
    ])
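
For concreteness, this is how I call it (shapes here are just made up for illustration):

# Illustrative usage with made-up shapes
volume = torch.randn(2, 32, 32, 32)                  # (C, D, H, W)
centers = torch.randint(0, 32, (1000, 3))            # (N, 3) integer voxel coordinates
bricks = gather_receiptive_fields(volume, centers)   # -> (1000, 2, 3, 3, 3)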

The result is a tensor of shape (N, C, ks, ks, ks) that contains all N ks^3 bricks around the given centers.
My problem is that for large N this becomes quite slow, so I wonder whether there is a PyTorch function I could abuse to do such indexing for all N in one operation, hopefully getting around that Python loop.

I’d appreciate any suggestions :slight_smile:
Cheers

From the description it sounds like this is similar to an Unfold operation: Unfold — PyTorch 1.13 documentation
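
For reference, a minimal 2D sketch of what unfold does (shapes just for illustration): it extracts all sliding patches in one call, without a Python loop.

import torch
import torch.nn.functional as F

image = torch.randn(1, 4, 32, 32)                    # (B, C, H, W)
patches = F.unfold(image, kernel_size=3, padding=1)  # (B, C*3*3, H*W), one column per position
patches = patches.view(1, 4, 3, 3, 32 * 32)          # one 3x3 patch per spatial position
# Picking the columns at the given center positions then yields the 3x3 receptive fields.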


Thanks for your answer! This does indeed look very related. Unfortunately, it only supports 2D image-like tensors at the moment, but I will investigate whether I can abuse it for my task as well.

One way to speed things up for my case (N ~ 65k with ks = (9,9,9)) is definitely to loop over the 9x9x9 = 729 brick entries (getting N values every iteration) instead of looping over N (getting only 729 entries per iteration).
This change gives me a speed-up of around 40x and gets me quite close to usable speed. Here’s the implementation:

def gather_receiptive_fields2(volume, centers, ks=3):
    L = ks // 2
    # all ks^3 brick-local offsets, already shifted by L to account for the padding
    offsets = [[L+i, L+j, L+k] for i in range(-L, L+1) for j in range(-L, L+1) for k in range(-L, L+1)]
    pad_vol = F.pad(volume, tuple([L]*6))
    # for each offset, gather that entry for all N centers at once -> stacked shape (ks^3, C, N)
    return torch.stack([pad_vol[...,
            centers[:, 0] + off[0],
            centers[:, 1] + off[1],
            centers[:, 2] + off[2]] for off in offsets
    ]).permute(2,1,0).reshape(centers.size(0), volume.size(0), ks, ks, ks).contiguous()  # (N, C, ks, ks, ks)

I still feel like there should be a faster approach, though.
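
One thing I might still try is to drop the loop over offsets entirely by building broadcastable index tensors up front and letting advanced indexing do all the work in one call. An untested sketch, assuming integer centers of shape (N, 3) as above (gather_receiptive_fields3 is just an illustrative name):

import torch
import torch.nn.functional as F

# Untested sketch: remove the loop over offsets via broadcast index tensors
def gather_receiptive_fields3(volume, centers, ks=3):
    L = ks // 2
    pad_vol = F.pad(volume, tuple([L] * 6))              # (C, D+2L, H+2L, W+2L)
    r = torch.arange(ks, device=centers.device)          # brick-local offsets 0 .. ks-1
    # index grids that broadcast together to (N, ks, ks, ks), one per spatial dim
    d = centers[:, 0, None, None, None] + r[:, None, None]
    h = centers[:, 1, None, None, None] + r[None, :, None]
    w = centers[:, 2, None, None, None] + r[None, None, :]
    # advanced indexing broadcasts d, h, w -> result (C, N, ks, ks, ks)
    return pad_vol[:, d, h, w].transpose(0, 1).contiguous()  # (N, C, ks, ks, ks)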

I would also take a look at external repos that have implemented an “Nd” Unfold, e.g. GitHub - f-dangel/unfoldNd: (N=1,2,3)-dimensional unfold (im2col) and fold (col2im) in PyTorch
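
If I remember its README correctly, it mirrors the torch.nn.Unfold interface but also accepts 5D (volumetric) input, so usage would look roughly like this (untested, going off the repo description; package and class names are my assumption):

import torch
import unfoldNd  # assumed package name

volume = torch.randn(1, 2, 16, 16, 16)                # (B, C, D, H, W)
unfold = unfoldNd.UnfoldNd(kernel_size=3, padding=1)  # assumed to mirror torch.nn.Unfold
patches = unfold(volume)                              # expected: (B, C*3*3*3, D*H*W)
# Selecting the columns at the flattened center positions (d*H*W + h*W + w)
# would then give the receptive fields for all centers at once.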
