Hey,
I’m trying to slice multiple contiguous regions from a larger tensor, preserving their shape.
Currently, I do this with the following function:
import torch
import torch.nn.functional as F

# volume is (C, D, H, W)
# centers is (N, 3), an integer tensor of (d, h, w) voxel coordinates
def gather_receiptive_fields(volume, centers, ks=3):
    L = ks // 2
    R = L + 1  # slice end indices are exclusive, thus +1; brick width is 2*L + 1, so ks should be odd
    pad_vol = F.pad(volume, tuple([L] * 6))  # pad in case a center is at the volume border
    return torch.stack([pad_vol[...,  # keep channel dims as is
                                coord[0] - L : coord[0] + R,
                                coord[1] - L : coord[1] + R,
                                coord[2] - L : coord[2] + R]
                        for coord in centers + L])  # shift by ks//2 to account for padding
The result is a tensor of shape (N, C, ks, ks, ks) that contains all N bricks of size ks^3 around the given centers.
My problem is that this becomes quite slow for large N, so I wonder whether there is a PyTorch function I could abuse to do this indexing for all N centers in one operation, avoiding the Python loop.
Thanks for your answer! This does indeed look very related. Unfortunately, it only supports 2D image-like tensors at the moment, but I will investigate whether I can abuse it for my task as well.
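For the 3D case, one option (not necessarily the one suggested in the answer) is the tensor method `Tensor.unfold(dimension, size, step)`. Unlike `torch.nn.functional.unfold`, which only works on 2D image-like spatial inputs, it slides a window along any single dimension, so chaining it over D, H and W yields a (C, D, H, W, ks, ks, ks) strided view of every brick without copying. Indexing that view with the centers then gathers only the N requested bricks. A minimal sketch, assuming `centers` is an int64 tensor of (d, h, w) coordinates and `ks` is odd; `gather_bricks_unfold` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def gather_bricks_unfold(volume, centers, ks=3):
    """Gather bricks via sliding-window views built with Tensor.unfold (hypothetical helper).

    volume:  (C, D, H, W) tensor
    centers: (N, 3) int64 tensor of (d, h, w) voxel coordinates
    returns: (N, C, ks, ks, ks) tensor; assumes ks is odd
    """
    L = ks // 2
    pad_vol = F.pad(volume, (L,) * 6)  # zero-pad so border bricks stay in bounds
    # Each unfold appends a window dim of size ks (a view, no copy):
    # (C, D+2L, H+2L, W+2L) -> (C, D, H, W, ks, ks, ks)
    windows = pad_vol.unfold(1, ks, 1).unfold(2, ks, 1).unfold(3, ks, 1)
    # windows[:, d, h, w] is the brick centered at original voxel (d, h, w),
    # so the centers are used directly (no +L shift needed).
    bricks = windows[:, centers[:, 0], centers[:, 1], centers[:, 2]]  # (C, N, ks, ks, ks)
    return bricks.permute(1, 0, 2, 3, 4).contiguous()
```

Because `windows` is only a strided view, the actual copying happens in the final advanced-indexing gather (and the `contiguous()` after the permute).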
One way to speed things up in my case (N ~ 65k, ks = 9, i.e. 9x9x9 bricks) is to loop over the 9x9x9 = 729 brick entries, gathering N values per iteration, instead of looping over the N centers and gathering only 729 entries per iteration.
This gives me a speedup of roughly 40x and gets me quite close to a usable speed. Here's the implementation:
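import torch
import torch.nn.functional as F

def gather_receiptive_fields2(volume, centers, ks=3):
    L = ks // 2
    # Per-axis shift L + i (i in -L..L) maps an original center c to padded index c + L + i
    offsets = [[L + i, L + j, L + k]
               for i in range(-L, L + 1)
               for j in range(-L, L + 1)
               for k in range(-L, L + 1)]
    pad_vol = F.pad(volume, tuple([L] * 6))
    # Each iteration gathers one brick entry for all N centers -> (C, N); stacking gives (ks^3, C, N),
    # then (ks^3, C, N) -> (N, C, ks^3) -> (N, C, ks, ks, ks)
    return torch.stack([pad_vol[...,
                                centers[:, 0] + off[0],
                                centers[:, 1] + off[1],
                                centers[:, 2] + off[2]] for off in offsets
                        ]).permute(2, 1, 0).reshape(centers.size(0), volume.size(0), ks, ks, ks).contiguous()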
I still feel like there should be a faster approach, though.
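The remaining loop over the 729 offsets in `gather_receiptive_fields2` can, in principle, be collapsed into a single advanced-indexing call by broadcasting per-axis offsets against the centers. A minimal sketch, assuming `centers` is an int64 tensor of (d, h, w) coordinates on the same device as `volume` and `ks` is odd; `gather_bricks_broadcast` is a hypothetical name, and whether it actually beats the loop versions would need to be measured:

```python
import torch
import torch.nn.functional as F


def gather_bricks_broadcast(volume, centers, ks=3):
    """Gather all N bricks in one advanced-indexing call (hypothetical helper).

    volume:  (C, D, H, W) tensor
    centers: (N, 3) int64 tensor of (d, h, w) voxel coordinates
    returns: (N, C, ks, ks, ks) tensor; assumes ks is odd
    """
    L = ks // 2
    pad_vol = F.pad(volume, (L,) * 6)  # zero-pad so border bricks stay in bounds
    # In padded coordinates, the brick around original center c spans c .. c + ks - 1.
    r = torch.arange(ks, device=centers.device)
    d = centers[:, 0].reshape(-1, 1, 1, 1) + r.reshape(1, -1, 1, 1)  # (N, ks, 1, 1)
    h = centers[:, 1].reshape(-1, 1, 1, 1) + r.reshape(1, 1, -1, 1)  # (N, 1, ks, 1)
    w = centers[:, 2].reshape(-1, 1, 1, 1) + r.reshape(1, 1, 1, -1)  # (N, 1, 1, ks)
    # The three index tensors broadcast to (N, ks, ks, ks); the result is (C, N, ks, ks, ks).
    bricks = pad_vol[:, d, h, w]
    return bricks.permute(1, 0, 2, 3, 4).contiguous()
```

The design choice here is to keep the index tensors small: each is only (N, ks) elements with singleton dims, and indexing broadcasts them, so there is no need to materialize full (N, ks, ks, ks) coordinate grids or to `torch.stack` 729 partial results.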