How to map a 3D feature volume into a stack of 2D features using pre-defined slicing planes

I am not sure if this is the right place to ask this technical question; I have also posted the same question on Stack Overflow. If you believe it is better suited there, please feel free to answer it there instead.

I wonder how I can build a function map_feat(feat, plane_mask) that maps a 3D feature [B, C, D, H, W] into a stack of 2D features [B, C, H’, W’] based on a few pre-defined slicing planes. Here is a toy example:

import torch
import numpy as np

def coordinates(voxel_dim, device=torch.device('cpu')):
    nx, ny, nz = voxel_dim
    x = torch.arange(nx, dtype=torch.long, device=device)
    y = torch.arange(ny, dtype=torch.long, device=device)
    z = torch.arange(nz, dtype=torch.long, device=device)
    x, y, z = torch.meshgrid(x, y, z, indexing='ij')  # 'ij': x is the slowest-varying index
    return torch.stack((x.flatten(), y.flatten(), z.flatten()))

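def get_plane_mask(verts, n1, n2, n3, tol=1e-6):
    # voxels (x, y, z) with n1*x + n2*y + n3*z == 1; the tolerance guards against
    # float rounding (e.g. 0.1 * x is not exact in float32)
    a = verts[0] * n1 + verts[1] * n2 + verts[2] * n3
    return (a - 1).abs() < tol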

I have a 3D feature volume from a network. To keep things simple, let’s say batch size = 1 and channels = 5:

voxel_dim = [10, 10, 5]
feat = torch.randn([1, 5] + voxel_dim)

And I have a few pre-defined slicing planes, each given by n1*x + n2*y + n3*z = 1 (here x + 2y = 10, x = 5 and z = 2):

verts = coordinates(voxel_dim)
plane_mask = []
plane_mask.append(get_plane_mask(verts, 0.1, 0.2, 0.))
plane_mask.append(get_plane_mask(verts, 0.2, 0, 0))
plane_mask.append(get_plane_mask(verts, 0, 0, 0.5))
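Because all masks are built over the same flattened voxel list, one possible way to use the sparse-tensor idea is to stack them into a [P, N] mask and gather every on-plane voxel with a single indexing call. A minimal sketch, assuming feat[..., x, y, z] is the voxel at verts[:, n] (true for the `indexing='ij'` meshgrid order); the names `normals`, `masks`, `plane_id`, `voxel_id` and `per_plane` are hypothetical, and the exact `== 1` test is replaced by a 1e-6 tolerance because products like 0.1 * x are not always exact in float32:

```python
import torch

voxel_dim = [10, 10, 5]
nx, ny, nz = voxel_dim
x, y, z = torch.meshgrid(torch.arange(nx), torch.arange(ny), torch.arange(nz),
                         indexing='ij')
verts = torch.stack((x.flatten(), y.flatten(), z.flatten()))  # [3, N], x-major order
feat = torch.randn([1, 5] + voxel_dim)

# Plane coefficients (n1, n2, n3) from the question, one row per plane.
normals = torch.tensor([[0.1, 0.2, 0.0],   # x + 2y = 10
                        [0.2, 0.0, 0.0],   # x = 5
                        [0.0, 0.0, 0.5]])  # z = 2
# All masks at once: [P, N] boolean, using a tolerance instead of exact == 1.
masks = ((normals @ verts.float()) - 1).abs() < 1e-6

# Sparse COO storage: the indices are the (plane, voxel) pairs, sorted by plane.
sparse_masks = masks.float().to_sparse().coalesce()
plane_id, voxel_id = sparse_masks.indices()          # each of shape [nnz]

# One gather for every plane, then split back per plane: [B, C, n_p] each.
vals = feat.flatten(start_dim=2)[:, :, voxel_id]     # [B, C, nnz]
per_plane = torch.split(vals, masks.sum(dim=1).tolist(), dim=-1)
```

Each entry of `per_plane` holds the raw voxel values of one plane in x-major order (25, 50 and 100 voxels here); turning them into an H’ × W’ image still needs a reshape and an interpolation step. `masks.nonzero()` returns the same (plane, voxel) pairs without going through a sparse tensor.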

Next, I want to index the feature values lying on these planes and interpolate them into a new 5D tensor, whose last dimension indexes the plane masks and whose first four dimensions are [B, C, H’, W’] as described above, i.e. batch size = 1, #channels = 5 (the same as feat), height = 8 and width = 8:

new_feat = torch.zeros([1, feat.shape[1], 8, 8, len(plane_mask)])
for i in range(len(plane_mask)):
    new_feat[:,:,:,:, i] = map_feat(feat, plane_mask[i])
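A mask-driven map_feat for the rectangular case could gather the masked voxels, reshape them into the h × w cross section and resize it with `torch.nn.functional.interpolate`. This is only a sketch: it assumes the masked voxels form a complete grid and that the rectangle's second edge runs along the last coordinate axis that varies over the mask, which holds for the three toy planes but not for an arbitrary plane. The `out_hw` parameter and the tolerance in `get_plane_mask` are hypothetical additions:

```python
import torch
import torch.nn.functional as F

def coordinates(voxel_dim, device=torch.device('cpu')):
    nx, ny, nz = voxel_dim
    x, y, z = torch.meshgrid(torch.arange(nx, device=device),
                             torch.arange(ny, device=device),
                             torch.arange(nz, device=device), indexing='ij')
    return torch.stack((x.flatten(), y.flatten(), z.flatten()))

def get_plane_mask(verts, n1, n2, n3, tol=1e-6):
    a = verts[0] * n1 + verts[1] * n2 + verts[2] * n3
    return (a - 1).abs() < tol

def map_feat(feat, plane_mask, out_hw=(8, 8)):
    B, C = feat.shape[:2]
    verts = coordinates(feat.shape[2:], feat.device)
    sel = verts[:, plane_mask]                          # [3, n] voxel coords on the plane
    vals = feat.flatten(start_dim=2)[:, :, plane_mask]  # [B, C, n], same x-major order
    # Assumption: the rectangle's second edge follows the last axis that varies.
    varying = [a for a in range(3) if sel[a].unique().numel() > 1]
    w = sel[varying[-1]].unique().numel()
    h = vals.shape[-1] // w
    cross_section = vals.reshape(B, C, h, w)
    # Bilinear resize of the h x w cross section to out_hw
    return F.interpolate(cross_section, size=out_hw, mode='bilinear',
                         align_corners=True)

voxel_dim = [10, 10, 5]
feat = torch.randn([1, 5] + voxel_dim)
verts = coordinates(voxel_dim)
plane_mask = [get_plane_mask(verts, 0.1, 0.2, 0.),
              get_plane_mask(verts, 0.2, 0, 0),
              get_plane_mask(verts, 0, 0, 0.5)]
new_feat = torch.zeros([1, feat.shape[1], 8, 8, len(plane_mask)])
for i in range(len(plane_mask)):
    new_feat[:, :, :, :, i] = map_feat(feat, plane_mask[i])
```

Here each plane's cross section (5 × 5, 10 × 5 and 10 × 10 voxels, in mask order) is resized independently to 8 × 8, so aspect ratios are not preserved, and the loop over planes remains because the masks select different numbers of voxels.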

This mapping is similar to a CT scan, in which a slicing plane gives a 2D image showing the cross section of a 3D volume. I feel it is also somewhat similar to ROI pooling.

Any ideas on how to implement this map_feat function are appreciated! It would be even better if the function could handle all planes in one pass, without a for loop. I can store the plane masks as a sparse tensor if that helps.
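For a single pass over all planes, one option (a swapped-in technique: trilinear sampling rather than indexing with the masks) is to describe each rectangle by a corner and two edge vectors and sample it with `torch.nn.functional.grid_sample`, using the grid's depth dimension for the plane index. The rectangles below are hand-derived from the three plane equations and clipped to the 10 × 10 × 5 volume; `map_feat_all`, `origins`, `us` and `vs` are hypothetical names, and the masks are not needed at all. Sketch only, assuming every spatial size is greater than 1 and all planes share the same H’ × W’:

```python
import torch
import torch.nn.functional as F

def map_feat_all(feat, origins, us, vs, out_hw=(8, 8)):
    """Sample P rectangles from feat in one grid_sample call.

    feat:   [B, C, X, Y, Z], indexed as feat[..., x, y, z].
    origins, us, vs: [P, 3] (x, y, z) voxel coordinates; rectangle i is
            origins[i] + s * us[i] + t * vs[i] for s, t in [0, 1].
    Returns [B, C, out_h, out_w, P], the layout of new_feat in the question.
    """
    B = feat.shape[0]
    opts = dict(dtype=feat.dtype, device=feat.device)
    origins, us, vs = (torch.as_tensor(p, **opts) for p in (origins, us, vs))
    s = torch.linspace(0, 1, out_hw[0], **opts)[None, :, None, None]  # [1, h, 1, 1]
    t = torch.linspace(0, 1, out_hw[1], **opts)[None, None, :, None]  # [1, 1, w, 1]
    # Sample points in voxel coordinates, [P, h, w, 3] ordered (x, y, z)
    pts = origins[:, None, None] + s * us[:, None, None] + t * vs[:, None, None]
    # Normalise to [-1, 1] (align_corners=True: -1/+1 are the first/last voxel centres)
    sizes = torch.tensor(list(feat.shape[2:]), **opts)
    norm = 2 * pts / (sizes - 1) - 1
    # grid_sample reads the last grid dim as (W, H, D) = (z, y, x), hence the flip.
    # Planes occupy the D_out slot: grid [B, P, h, w, 3] -> output [B, C, P, h, w].
    grid = norm.flip(-1)[None].repeat(B, 1, 1, 1, 1)
    out = F.grid_sample(feat, grid, mode='bilinear', align_corners=True)  # trilinear for 5-D
    return out.permute(0, 1, 3, 4, 2)

feat = torch.randn(1, 5, 10, 10, 5)
origins = [(0, 5, 0), (5, 0, 0), (0, 0, 2)]     # x + 2y = 10, x = 5, z = 2
us      = [(9, -4.5, 0), (0, 9, 0), (9, 0, 0)]
vs      = [(0, 0, 4), (0, 0, 4), (0, 9, 0)]
new_feat = map_feat_all(feat, origins, us, vs)  # [1, 5, 8, 8, 3]
```

grid_sample is differentiable with respect to feat, so this can sit inside a network. If the masks must stay the source of truth, the corners and edge vectors would first have to be extracted from them, which this sketch does not attempt.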

Currently, I am only considering the case where the cross section is a rectangle; there is no need to deal with more general cases (triangle, pentagon, etc.) at this point. Thank you in advance.