Multiple patches from a tensor

I am working with 3D images. Given a single image tensor of size d x w x h, I would like to extract many (~100) same-sized 3D patches (e.g. 32 x 32 x 32). The patches are defined by a tensor of size n_patches x 3 whose rows give the minimum index of each patch along each of the 3 dimensions, i.e. the corner where the patch starts.

Ideally, the output would be an n_patches x 32 x 32 x 32 tensor containing all of the extracted patches.

Is there a more elegant or faster way to do this indexing than looping n_patches times and concatenating the patches?
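For reference, the loop-and-concatenate baseline described above might look like the sketch below. The function name extract_patches_loop is hypothetical, and it assumes every patch lies fully inside the volume (no bounds checking or padding). It uses torch.stack, which adds the leading n_patches dimension that a plain torch.cat would not.

```python
import torch

def extract_patches_loop(x: torch.Tensor, starts: torch.Tensor, size: int) -> torch.Tensor:
    """Naive baseline (hypothetical helper): one slice per patch, then stack.

    x:      (d, w, h) image tensor
    starts: (n_patches, 3) integer tensor of per-axis start (minimum) indices
    size:   edge length of each cubic patch
    returns (n_patches, size, size, size)
    """
    patches = []
    for i, j, k in starts.tolist():
        # Assumes i + size <= d, j + size <= w, k + size <= h.
        patches.append(x[i:i + size, j:j + size, k:k + size])
    return torch.stack(patches)

x = torch.arange(64 * 64 * 64).reshape(64, 64, 64)
# Start indices chosen so that every 32^3 patch fits inside the 64^3 volume.
starts = torch.randint(0, 64 - 32 + 1, (100, 3))
out = extract_patches_loop(x, starts, 32)
print(out.shape)  # torch.Size([100, 32, 32, 32])
```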

A toy example in 2D:

Say the image tensor X is

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])

and the tensor of patch start indices (row, column) is

tensor([[1, 6],
        [2, 0],
        [2, 8]])

then, with patches of size 2 x 2, the result would be

tensor([[[16, 17],
         [26, 27]],

        [[20, 21],
         [30, 31]],

        [[28, 29],
         [38, 39]]])
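One loop-free way to get this result is to build broadcastable index tensors and do a single advanced-indexing gather. The sketch below is one possible approach, not a canonical PyTorch API; the names extract_patches_2d and extract_patches_3d are hypothetical, starts is assumed to be an integer tensor, and all patches are assumed to lie inside the image. Each axis gets its own index tensor (start + 0..size-1), shaped so that broadcasting produces one full grid per patch.

```python
import torch

def extract_patches_2d(x: torch.Tensor, starts: torch.Tensor, size: int) -> torch.Tensor:
    """Gather n square patches from a 2D tensor with one advanced-indexing call."""
    offs = torch.arange(size, device=x.device)        # 0, 1, ..., size-1
    rows = (starts[:, 0, None] + offs)[:, :, None]    # (n, size, 1) absolute row indices
    cols = (starts[:, 1, None] + offs)[:, None, :]    # (n, 1, size) absolute column indices
    return x[rows, cols]                              # broadcasts to (n, size, size)

def extract_patches_3d(x: torch.Tensor, starts: torch.Tensor, size: int) -> torch.Tensor:
    """Same idea for a (d, w, h) volume: three broadcastable index tensors."""
    offs = torch.arange(size, device=x.device)
    idx = starts[:, :, None] + offs                   # (n, 3, size) absolute indices per axis
    i = idx[:, 0, :, None, None]                      # (n, size, 1, 1)
    j = idx[:, 1, None, :, None]                      # (n, 1, size, 1)
    k = idx[:, 2, None, None, :]                      # (n, 1, 1, size)
    return x[i, j, k]                                 # (n, size, size, size)

X = torch.arange(40).reshape(4, 10)
starts = torch.tensor([[1, 6], [2, 0], [2, 8]])
print(extract_patches_2d(X, starts, 2).tolist())
# [[[16, 17], [26, 27]], [[20, 21], [30, 31]], [[28, 29], [38, 39]]]
```

For the 3D case in the question, extract_patches_3d(volume, starts, 32) would return the n_patches x 32 x 32 x 32 tensor directly; the gather copies only the selected voxels.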

torch.nn.Unfold may be useful.
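A caveat on torch.nn.Unfold: it only accepts 4-D batched input of shape (N, C, H, W) and returns the patches flattened into columns, so it does not map directly onto a single 3D volume. The sketch below swaps in the related tensor method Tensor.unfold(dimension, size, step), which returns a strided view that appends one window dimension per call and works on any axis. The helper name extract_patches_unfold is hypothetical, and the sketch assumes integer start indices with every patch inside the volume.

```python
import torch

def extract_patches_unfold(x: torch.Tensor, starts: torch.Tensor, size: int) -> torch.Tensor:
    """Pick patches out of a strided view of every possible window (hypothetical helper).

    After the three unfold calls, windows[a, b, c] == x[a:a+size, b:b+size, c:c+size].
    """
    windows = x.unfold(0, size, 1).unfold(1, size, 1).unfold(2, size, 1)
    # windows: (d-size+1, w-size+1, h-size+1, size, size, size), still a view of x (no copy)
    # Advanced indexing on the first three dims copies only the n chosen patches.
    return windows[starts[:, 0], starts[:, 1], starts[:, 2]]

x = torch.randn(64, 64, 64)
starts = torch.randint(0, 64 - 32 + 1, (100, 3))
patches = extract_patches_unfold(x, starts, 32)
print(patches.shape)  # torch.Size([100, 32, 32, 32])
```

The same trick reproduces the 2D toy example: X.unfold(0, 2, 1).unfold(1, 2, 1)[starts[:, 0], starts[:, 1]] selects the three 2 x 2 patches. Because the unfolded tensor is a view, memory is only allocated for the gathered patches, not for all possible windows.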