Hi there,
I have a small function that performs cropping/padding in fourier space. Input is a tensor of which the last two (image) or alst three (volume) dimensions are results from an fft.rfftn
call. Therefore I needed multiple index_put_
operations to correctly fill the cropped/padded FFT:
using namespace torch::indexing;
torch::Tensor fft_crop(torch::Tensor& fft_volume, int dim, int new_x, int new_y, int new_z) {
auto n = fft_volume.ndimension();
auto newDims = fft_volume.sizes().vec();
newDims[n - 1] = (int64_t)new_x / 2 + 1;
if (dim >= 2)
newDims[n - 2] = new_y;
if (dim == 3)
newDims[n - 3] = new_z;
auto newVol = torch::empty(newDims, fft_volume.options());
int old_x = (fft_volume.size(-1)-1)*2;
int x = std::min(old_x, new_x);
int old_y = fft_volume.size(-2);
int y = std::min(old_y, new_y);
if (dim == 2) {
newVol.index_put_({ "...", Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) },
fft_volume.index({ "...",Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) }));
newVol.index_put_({ "...", Slice(new_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) },
fft_volume.index({ "...", Slice(old_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) }));
}
if (dim == 3) {
int old_z = fft_volume.size(-3);
int z = std::min(old_z, new_z);
newVol.index_put_({ "...", Slice(0, z / 2 + 1), Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) },
fft_volume.index({ "...", Slice(0, z / 2 + 1), Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) }));
newVol.index_put_({ "...", Slice(0, z / 2 + 1), Slice(new_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) },
fft_volume.index({ "...", Slice(0, z / 2 + 1), Slice(old_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) }));
newVol.index_put_({ "...", Slice(new_z - z + z / 2 + 1, None), Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) },
fft_volume.index({ "...", Slice(old_z - z + z / 2 + 1, None), Slice(0, y / 2 + 1), Slice(0, x / 2 + 1) }));
newVol.index_put_({ "...", Slice(new_z - z + z / 2 + 1, None), Slice(new_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) },
fft_volume.index({ "...", Slice(old_z - z + z / 2 + 1, None), Slice(old_y - y + (y / 2 + 1), None), Slice(0, x / 2 + 1) }));
}
return newVol;
}
I am now wondering about the efficiency of this code. Basically, I am just copying different parts of the fft_volume
into my new_volume
. Is there any way to increase performance here, i.e. combining this into a single operation?