Map a tensor at each index based on its neighbors?

The problem I am trying to solve is a transformation on a 3D tensor X. For each entry X[i][j][k], I want to evaluate a function f(X[i][j][k]) that depends on the 8 neighbors (ignoring edge cases) of X[i][j][k], namely X[i±1][j±1][k±1]. Can this be done efficiently in PyTorch?

Hi Eco!

If I understand your use case properly, yes.

Construct a specific kernel for a Conv3d that plucks out the eight shifted
tensors you want. It will be a 3x3x3 kernel with one input channel and eight
output channels, where each output channel's 3x3x3 slice contains a single 1.0
(and zeros elsewhere) at the position that corresponds to the desired neighbor.
Leave out the bias (bias=False, or use the functional conv3d) so nothing is
added to the plucked-out values.
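A minimal sketch of building such a kernel, assuming the "eight neighbors" are the corner offsets (di, dj, dk) in {-1, +1}^3 from the question. The helper name make_neighbor_kernel is hypothetical, chosen only for illustration:

```python
import itertools

import torch


def make_neighbor_kernel(dtype=torch.float32):
    """Hypothetical helper: build an (8, 1, 3, 3, 3) Conv3d weight with one
    output channel per corner offset (di, dj, dk) in {-1, +1}^3, each channel
    holding a single 1.0 and zeros elsewhere."""
    offsets = list(itertools.product((-1, 1), repeat=3))
    weight = torch.zeros(len(offsets), 1, 3, 3, 3, dtype=dtype)
    for c, (di, dj, dk) in enumerate(offsets):
        # Index 1 is the kernel center, so (1 + di, 1 + dj, 1 + dk) is the
        # kernel position of neighbor (di, dj, dk) relative to the center.
        weight[c, 0, 1 + di, 1 + dj, 1 + dk] = 1.0
    return weight, offsets


weight, offsets = make_neighbor_kernel()
print(weight.shape)                   # torch.Size([8, 1, 3, 3, 3])
print(weight.sum(dim=(1, 2, 3, 4)))   # tensor([1., 1., 1., 1., 1., 1., 1., 1.])
```

Because PyTorch convolutions compute a cross-correlation (the kernel is not flipped), a 1.0 at position (1 + di, 1 + dj, 1 + dk) selects X[i + di][j + dj][k + dk] for output voxel (i, j, k).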

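When you apply that Conv3d to a (single-channel) tensor, its output will be
an eight-channel tensor in which each channel is a copy of the input, shifted
so that the desired “neighbor” voxels line up with the corresponding voxels of
the input tensor. With padding=1 the output keeps the input's spatial size, and
f can then be evaluated elementwise across the eight channels.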
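Here is a sketch of applying the kernel end to end. The names neighbors and OFFSETS are hypothetical, and the f used here (mean of the eight neighbors) is only a stand-in for whatever function you actually need. The kernel is built in one shot with F.one_hot instead of a loop; edge voxels simply see zeros from the padding.

```python
import itertools

import torch
import torch.nn.functional as F

# The eight corner offsets (di, dj, dk) in {-1, +1}^3.
OFFSETS = list(itertools.product((-1, 1), repeat=3))


def neighbors(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x of shape (D, H, W) -> (8, D, H, W) with
    out[c, i, j, k] == x[i + di, j + dj, k + dk] for OFFSETS[c] = (di, dj, dk).
    Neighbors outside x read as 0 because of the zero padding."""
    # Flat index of each offset inside a 3x3x3 kernel (row-major order).
    flat = torch.tensor([(1 + di) * 9 + (1 + dj) * 3 + (1 + dk) for di, dj, dk in OFFSETS])
    weight = F.one_hot(flat, num_classes=27).to(x.dtype).view(8, 1, 3, 3, 3)
    # conv3d expects (batch, channels, D, H, W); padding=1 keeps D, H, W unchanged.
    return F.conv3d(x[None, None], weight, padding=1)[0]


def f(nbrs: torch.Tensor) -> torch.Tensor:
    # Hypothetical f: the mean of each voxel's eight neighbors.
    return nbrs.mean(dim=0)


torch.manual_seed(0)
x = torch.randn(5, 6, 7)
nb = neighbors(x)  # shape (8, 5, 6, 7)
y = f(nb)          # shape (5, 6, 7)

# Spot-check one interior voxel against direct indexing.
i, j, k = 2, 3, 4
direct = torch.stack([x[i + di, j + dj, k + dk] for di, dj, dk in OFFSETS])
assert torch.allclose(nb[:, i, j, k], direct)
```

Any f built from elementwise or reduction ops over dim 0 of nb runs on the whole tensor at once, with no Python loop over voxels.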

Best.

K. Frank
