The problem I am trying to solve is a transformation on a 3D tensor X. For each entry X[i][j][k], I want to evaluate a function f that depends on the 8 neighbors of X[i][j][k] (ignoring edge cases), namely X[i±1][j±1][k±1]. Can this be done efficiently in PyTorch?
Hi Eco!
If I understand your use case properly, yes.
Construct a specific kernel for a Conv3d that plucks out the eight shifted tensors that you want. So it will be a 3x3x3 kernel with one input channel and eight output channels, with each channel plane containing a single 1.0 in the position that corresponds to the desired neighbor.
When you apply that Conv3d to a (single-channel) tensor, its output will be an eight-channel tensor whose channels are tensors that have been shifted so that the desired “neighbor” pixels line up with the corresponding pixels in the input tensor.
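Here is a minimal sketch of that idea, in case it helps. The function name corner_neighbors and the returned offsets list are just illustrative names, not anything built into PyTorch. I am also assuming a floating-point input of shape (D, H, W) and zero padding (padding=1), so neighbors that fall outside the tensor read as 0; handle the edges differently if your use case needs it.

```python
import itertools

import torch
import torch.nn.functional as F


def corner_neighbors(x):
    """Gather the 8 neighbors X[i±1][j±1][k±1] of every entry.

    x: floating-point tensor of shape (D, H, W).
    Returns (neighbors, offsets), where neighbors has shape (8, D, H, W)
    and neighbors[c, i, j, k] == x[i + di, j + dj, k + dk] for
    (di, dj, dk) = offsets[c]. Out-of-range neighbors are 0 (zero padding).
    """
    # All 8 sign combinations of (±1, ±1, ±1).
    offsets = list(itertools.product((-1, 1), repeat=3))

    # Weight layout for conv3d: (out_channels, in_channels, kD, kH, kW).
    weight = torch.zeros(8, 1, 3, 3, 3, dtype=x.dtype, device=x.device)
    for c, (di, dj, dk) in enumerate(offsets):
        # Kernel center is index (1, 1, 1); conv3d is a cross-correlation,
        # so a 1.0 at (1 + di, 1 + dj, 1 + dk) selects x[i+di, j+dj, k+dk].
        weight[c, 0, 1 + di, 1 + dj, 1 + dk] = 1.0

    # Add batch and channel dimensions: (1, 1, D, H, W).
    out = F.conv3d(x[None, None], weight, padding=1)
    return out[0], offsets
```

With the eight shifted copies stacked along the first dimension, f becomes an ordinary elementwise computation. For example, if f were the mean of the eight neighbors, it would be neighbors.mean(dim=0) with neighbors, offsets = corner_neighbors(X).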
Best.
K. Frank