How to slice multiple blocks efficiently from a torch tensor?


(Anand M) #1

I want to cut 3D slices out of a tensor of size [N,C,H,W] based on some indices. The indices have shape [B,3], where B is the number of slices and the 3 values are N, h0, w0: N is the batch index in the input, and h0, w0 give the top-left corner of the slice to cut.


import torch

def gather_blocks(input, indices, ksize, kstride):
    n, h0, w0 = indices[0]  # first row: batch index and block corner (in units of kstride)
    gathered = input[n:n+1, :, h0*kstride[0]:h0*kstride[0]+ksize[0], w0*kstride[1]:w0*kstride[1]+ksize[1]]
    for n, h0, w0 in indices[1:]:  # each torch.cat re-copies everything gathered so far
        gathered = torch.cat((gathered, input[n:n+1, :, h0*kstride[0]:h0*kstride[0]+ksize[0], w0*kstride[1]:w0*kstride[1]+ksize[1]]), 0)
    return gathered

input is the tensor of size [N,C,H,W], where N is the batch size, C the number of channels, and H, W the height and width. ksize is the size of the block I want to slice, and kstride scales the indices up to image coordinates (the block's corner is at h0*kstride[0], w0*kstride[1]).
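Part of the cost in the loop above is that torch.cat allocates a new tensor and copies everything gathered so far on every iteration, so the total copying grows quadratically with B. A minimal sketch of the usual fix, keeping the same slicing but concatenating only once at the end (the function name gather_blocks_list is hypothetical):

```python
import torch

def gather_blocks_list(input, indices, ksize, kstride):
    # Collect every block (a view, no copy yet) in a Python list ...
    blocks = []
    for b, h0, w0 in indices:
        hs, ws = h0 * kstride[0], w0 * kstride[1]  # top-left corner in pixels
        blocks.append(input[b:b + 1, :, hs:hs + ksize[0], ws:ws + ksize[1]])
    # ... then concatenate once instead of re-copying the growing result each iteration
    return torch.cat(blocks, 0)
```

This still runs a Python loop over the B rows, so it removes the repeated copying but not the per-slice indexing overhead.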

This approach is not fast enough for my purposes. Is there a better way to do this? I am new to PyTorch.
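One way to drop the Python loop entirely, not taken from this thread, is advanced indexing: build the row and column coordinates of every block with torch.arange and pick all B blocks in a single indexing call. A sketch under the assumptions that indices is an integer tensor of shape [B, 3] and that every block lies fully inside the image (gather_blocks_vectorized is a hypothetical name):

```python
import torch

def gather_blocks_vectorized(input, indices, ksize, kstride):
    """Gather B blocks of shape [C, ksize[0], ksize[1]] without a Python loop.

    indices: integer tensor [B, 3] with rows (n, h0, w0); h0, w0 are in units of kstride.
    Assumes every block lies inside the image.
    """
    n, h0, w0 = indices.to(input.device).long().unbind(1)  # each of shape [B]
    dev = input.device
    rows = h0[:, None] * kstride[0] + torch.arange(ksize[0], device=dev)  # [B, k0]
    cols = w0[:, None] * kstride[1] + torch.arange(ksize[1], device=dev)  # [B, k1]
    # Dims 0, 2 and 3 are indexed with broadcastable tensors, dim 1 (C) stays a slice.
    # Because the indexed dims are not adjacent, the broadcast shape [B, k0, k1]
    # comes first and C is appended last, giving [B, k0, k1, C].
    out = input[n[:, None, None], :, rows[:, :, None], cols[:, None, :]]
    return out.permute(0, 3, 1, 2)  # -> [B, C, k0, k1]
```

The output holds only the B gathered blocks, and the per-block Python overhead goes away because all coordinates are computed as tensors.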


(Jayakrishna Rudra) #2

Can you give a dummy example of the problem?

If I understand correctly, you have an input of shape (N,C,H,W) and want to get an output of shape (B,N,h0,w0).

Let’s say (N,h0,w0) is a tuple stored in the variable index:

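def select(input, index, B):
    N, C, H, W = input.shape
    N, h0, w0 = index  # N is reassigned here to the batch index taken from index
    input = input.transpose(1, 0)  # input will now have shape (C, N, H, W)
    output = input[B, N, h0, w0]
    return output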

I assume the B, h0, w0 values used for slicing start from zero. If not, you can add the stride value to them.
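A quick check of the shape claim above, using arbitrary sizes chosen for illustration: transpose(1, 0) swaps the N and C axes and returns a view that shares memory with the original tensor rather than a copy.

```python
import torch

x = torch.randn(4, 3, 6, 6)  # (N, C, H, W) with illustrative sizes
xt = x.transpose(1, 0)       # (C, N, H, W): dims 0 and 1 swapped
# transpose only changes the strides; both tensors point at the same memory
same_storage = xt.data_ptr() == x.data_ptr()
```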


(Anand M) #3

A dummy example would be:

Let’s say I have an input tensor of size [64,3,30,30] and I want to cut out slices of shape [3, 5, 5], where 5 = ksize[0].
The indices are of the form:

indices = [
[0, 1, 2],
[1, 0, 5],
[60, 4, 1]
]

where each row gives N, h0, w0 for one slice. Now I want to cut a slice of size [C, 5, 5] from input using these indices, which I do as follows:

block = input[N, :, h0:h0+5, w0:w0+5]

So the indices give the top-left corner of the slice I want to cut.
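The dummy example can be run end to end by applying the slicing line above to every row and stacking the results into one [B, C, 5, 5] tensor. A sketch that assumes kstride = 1 (so the indices are already pixel offsets) and uses random data in place of the real input:

```python
import torch

x = torch.randn(64, 3, 30, 30)                # stands in for the [N, C, H, W] input
indices = [[0, 1, 2], [1, 0, 5], [60, 4, 1]]  # rows: (batch index, top offset, left offset)
k = 5                                         # ksize[0] == ksize[1] == 5

# Indexing with an integer batch index drops that dim, so each block is [C, 5, 5]
blocks = [x[n, :, top:top + k, left:left + k] for n, top, left in indices]
out = torch.stack(blocks)                     # [B, C, 5, 5] == [3, 3, 5, 5]
```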


(Jayakrishna Rudra) #4

Yeah, this would work!

Change the line above to output = input[B,N,h0:h0+5,w0:w0+5] and then swap axes 0 and 1 again.
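Another option, not suggested in the thread, is Tensor.unfold, which exposes every ksize window taken at a kstride step as extra dimensions of a view, so all blocks can then be picked with one indexing call. A minimal sketch, assuming indices is an integer tensor of shape [B, 3] holding (n, h0, w0) in units of kstride and that every block fits inside the image (unfold does not produce windows that would run past the border); gather_blocks_unfold is a hypothetical name:

```python
import torch

def gather_blocks_unfold(input, indices, ksize, kstride):
    # unfold(dim, size, step) adds the window as a trailing dim without copying:
    # [N, C, nH, nW, k0, k1], where window (i, j) starts at pixel (i*kstride[0], j*kstride[1]).
    windows = input.unfold(2, ksize[0], kstride[0]).unfold(3, ksize[1], kstride[1])
    n, h0, w0 = indices.to(input.device).long().unbind(1)  # each of shape [B]
    # Indexing dims 0, 2, 3 (separated by the ':' on C) puts B first: [B, C, k0, k1]
    return windows[n, :, h0, w0]
```

Only the final indexing step copies data, and only for the B selected windows.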