I want to extract 3D slices from a tensor of size [N,C,H,W] based on a set of indices. The indices tensor has shape [B,3], where B is the number of slices and each row holds (n, h0, w0): n is the index into the batch dimension of the input, and h0, w0 give the top-left corner of the slice to cut (before scaling by kstride).
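For concreteness, here is a small hypothetical indices tensor in that layout (the values are made up for illustration). Each row is one requested slice, and splitting the columns gives three length-B tensors of batch indices and corner coordinates.

```python
import torch

# Hypothetical example: B = 3 requested slices.
# Each row is (n, h0, w0): batch index, then corner row and column (before kstride scaling).
indices = torch.tensor([[0, 1, 2],
                        [1, 0, 0],
                        [0, 3, 1]], dtype=torch.long)

# Split the columns into three tensors of shape [B].
n, h0, w0 = indices.unbind(dim=1)
```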

```
import torch

def gather(input, indices, ksize, kstride):
    # Cut the first slice, then concatenate each remaining one along dim 0
    n, h0, w0 = indices[0]
    gathered = input[n:n+1, :, h0*kstride[0]:h0*kstride[0]+ksize[0], w0*kstride[1]:w0*kstride[1]+ksize[1]]
    for n, h0, w0 in indices[1:]:
        gathered = torch.cat((gathered, input[n:n+1, :, h0*kstride[0]:h0*kstride[0]+ksize[0], w0*kstride[1]:w0*kstride[1]+ksize[1]]), 0)
    return gathered
```

input is the tensor of size [N,C,H,W], where N is the batch size, C the number of channels, and H, W the height and width. ksize is the (height, width) size of the block I want to slice, and kstride is used to scale the indices up to image coordinates.
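A minimal sketch of the same windowing done for all B rows at once, using broadcasting and advanced indexing instead of a Python loop with repeated torch.cat. It assumes indices is an integer [B,3] tensor, that every window lies fully inside the H x W image, and that ksize and kstride are (height, width) pairs; gather_patches is a hypothetical name, not a PyTorch function.

```python
import torch


def gather_patches(input, indices, ksize, kstride):
    """Hypothetical vectorized gather returning a [B, C, ksize[0], ksize[1]] tensor.

    Assumes `indices` holds (n, h0, w0) rows and every window is in bounds.
    """
    indices = indices.to(device=input.device, dtype=torch.long)
    n, h0, w0 = indices.unbind(dim=1)  # each has shape [B]

    # Upscale the corners to pixel coordinates, as in the loop version.
    top = h0 * kstride[0]   # [B]
    left = w0 * kstride[1]  # [B]

    # Row and column coordinates of every window: [B, kh] and [B, kw].
    rows = top[:, None] + torch.arange(ksize[0], device=input.device)
    cols = left[:, None] + torch.arange(ksize[1], device=input.device)

    # Move channels last so the three index tensors address adjacent dims;
    # they broadcast to [B, kh, kw], so the result is [B, kh, kw, C].
    x = input.permute(0, 2, 3, 1)
    patches = x[n[:, None, None], rows[:, :, None], cols[:, None, :]]

    # Back to channels-first: [B, C, kh, kw].
    return patches.permute(0, 3, 1, 2)


# Usage with made-up shapes: N=2, C=3, H=W=8, 2x2 windows, stride 2.
inp = torch.randn(2, 3, 8, 8)
idx = torch.tensor([[0, 1, 2], [1, 0, 0], [0, 3, 1]])
out = gather_patches(inp, idx, ksize=(2, 2), kstride=(2, 2))
print(out.shape)  # torch.Size([3, 3, 2, 2])
```

Like torch.cat, advanced indexing returns a new tensor rather than a view of input, so the result has the same semantics as the loop, but the work happens in a single indexing kernel instead of B slicing and concatenation steps.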

This approach is not fast enough for my purposes. Is there a better way to do this? I am new to PyTorch.