Fill 0 values in a tensor with nearest neighbors

Hi, I have a tensor of size (1, 320, 1216) whose values are 0 at most positions. I want to replace those zeros with the values of their nearest non-zero neighbors. What is the most efficient way to do this, given that it has to run online during training?

I don’t find that really straightforward.
First of all, you would need to calculate which element is the nearest neighbor.
You can get the non-zero elements with tensor[tensor != 0] (or tensor[tensor > 0] if all valid values are positive).
This is already a bit expensive.
But this way you lose their positions in the grid,
and recovering those is hard.

In fact, you could try to use sparse tensors (https://pytorch.org/docs/stable/sparse.html), as they store the positions of the non-zero elements.
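A minimal sketch of what "storing the positions" looks like, assuming a PyTorch version that provides `Tensor.to_sparse()` (COO layout). The small `dense` tensor is made up for illustration:

```python
import torch

# Toy 2-D tensor with only two non-zero entries (illustrative data).
dense = torch.tensor([[0.0, 0.7, 0.0],
                      [0.5, 0.0, 0.0]])

# Convert to COO; coalesce() guarantees sorted, de-duplicated indices.
sparse = dense.to_sparse().coalesce()

print(sparse.indices())   # shape (2, nnz): row indices, then column indices
print(sparse.values())    # the non-zero values, in the same order as the indices
print(sparse.to_dense())  # converts back to the original dense tensor
```

This only records where the known values are; filling the zeros still needs a separate nearest-neighbor step on top of it.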

However, as you can see here:

The developer points out that there is no easy way to obtain a sparse tensor from a dense one, which is basically the problem you are trying to solve.

If you can construct the sparse tensor, you can achieve good performance.
Otherwise, I would recommend finding another way to fill the zero elements, or relaxing your criteria about what counts as the nearest.

For example, something very silly that would be fast is to run a max pooling with stride 1 (the operation takes the max over the kernel window). If your matrix is really sparse and your kernel is big enough, you would be taking some sort of nearest neighbor.
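A minimal sketch of the max-pooling idea, assuming valid entries are strictly positive and 0 marks a missing value (negative values would not propagate correctly through a max). Instead of one big kernel, it repeats a small stride-1 `F.max_pool2d` and overwrites only the zeros, so each pass grows the known region by one ring of pixels. The result is approximate: within a window it takes the largest value, not the one closest in Euclidean terms. The function name `fill_zeros_maxpool` and its parameters are hypothetical.

```python
import torch
import torch.nn.functional as F


def fill_zeros_maxpool(x: torch.Tensor, kernel_size: int = 3, max_iters: int = 100) -> torch.Tensor:
    """Approximately fill zeros by repeated stride-1 max pooling.

    Assumes valid entries are > 0 and 0 means "missing".
    x has shape (C, H, W) or (N, C, H, W); kernel_size should be odd
    so that the spatial size is preserved.
    """
    unbatched = x.dim() == 3
    out = x.unsqueeze(0) if unbatched else x  # max_pool2d expects (N, C, H, W)
    pad = kernel_size // 2
    for _ in range(max_iters):
        missing = out == 0
        if not missing.any():
            break
        # Max over each kernel window; padding is treated as -inf by max pooling.
        pooled = F.max_pool2d(out, kernel_size, stride=1, padding=pad)
        # Only overwrite missing entries; known values stay untouched.
        out = torch.where(missing, pooled, out)
    return out.squeeze(0) if unbatched else out


# Usage on a tensor shaped like the one in the question (C, H, W):
t = torch.zeros(1, 320, 1216)
t[0, 100, 200] = 1.5
filled = fill_zeros_maxpool(t, max_iters=1500)
```

Everything stays on the tensor's device, so it can run on the GPU; note that `missing.any()` in the loop condition forces a host sync on each pass.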

hope it helps

I agree with most of the points that you have made. However, my problem is to create dense tensors from sparse tensors, not the other way around. I have a tensor that contains values at only 25% of the indices, e.g. if there are 100 entries in total, only 25 hold values and the other 75 are 0. I want to fill those 0 values with the values of their nearest non-zero neighbors. I thought there would be a clean way to simply fill the 0 values of a tensor.
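For an exact nearest-neighbor fill, one option outside PyTorch is SciPy's Euclidean distance transform: `scipy.ndimage.distance_transform_edt` with `return_indices=True` returns, for every element, the coordinates of the nearest zero-valued element of its input. Passing it the mask of missing entries therefore yields, for each position, the coordinates of the nearest known value. This is a swapped-in technique, not one proposed in this thread, and it runs on the CPU, so a GPU tensor is copied to the host and back. The helper name `fill_zeros_edt` is hypothetical.

```python
import torch
from scipy import ndimage


def fill_zeros_edt(x: torch.Tensor) -> torch.Tensor:
    """Replace zeros with the value of the nearest non-zero entry (Euclidean)."""
    arr = x.detach().cpu().numpy()
    missing = arr == 0
    # Nothing to fill, or nothing to fill from.
    if not missing.any() or missing.all():
        return x.clone()
    # For each element, indices of the nearest element where `missing` is False
    # (i.e. a known value); known elements map to themselves.
    idx = ndimage.distance_transform_edt(
        missing, return_distances=False, return_indices=True
    )
    filled = arr[tuple(idx)]
    return torch.from_numpy(filled).to(x.device)
```

For a (1, 320, 1216) tensor the leading axis has size 1, so the search is effectively 2-D; with several channels one would apply it per channel, otherwise the search also crosses channels.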

Oh btw, I remembered that there is a function that returns those indices:

```python
import torch

with torch.no_grad():
    tensor = torch.rand(4, 4)
    # Generate zero elements
    tensor[tensor < 0.5] = 0
    print(tensor)
    # Obtain the (row, col) positions of the non-zero elements
    print(tensor.nonzero())
```

Output (the values are random, so yours will differ):

```
tensor([[0.0000, 0.7300, 0.7180, 0.0000],
        [0.0000, 0.7008, 0.5255, 0.0000],
        [0.6798, 0.0000, 0.0000, 0.6164],
        [0.9133, 0.0000, 0.7820, 0.0000]])
tensor([[0, 1],
        [0, 2],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 3],
        [3, 0],
        [3, 2]])
```

So, given this, you can easily apply the interpolation.
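A brute-force sketch of that interpolation for a 2-D tensor, assuming "nearest" means Euclidean distance between grid coordinates: `nonzero()` gives the coordinates of the known and the missing entries, `torch.cdist` measures every missing-to-known distance, and `argmin` picks the closest known entry. The helper name `fill_zeros_nearest` and the `chunk` parameter are hypothetical; chunking only bounds memory, since the cost still grows with (number of zeros) × (number of non-zeros).

```python
import torch


def fill_zeros_nearest(x: torch.Tensor, chunk: int = 256) -> torch.Tensor:
    """Replace zeros in a 2-D tensor with the nearest non-zero value (brute force)."""
    out = x.clone()
    known = x.nonzero()            # (K, 2) row/col of non-zero entries
    missing = (x == 0).nonzero()   # (M, 2) row/col of zero entries
    if known.numel() == 0 or missing.numel() == 0:
        return out
    known_f = known.float()
    known_vals = x[known[:, 0], known[:, 1]]
    # Process the missing positions in chunks to bound the distance-matrix size.
    for start in range(0, missing.shape[0], chunk):
        rows = missing[start:start + chunk]
        dist = torch.cdist(rows.float(), known_f)  # (len(rows), K) Euclidean distances
        nearest = dist.argmin(dim=1)               # index of the closest known entry
        out[rows[:, 0], rows[:, 1]] = known_vals[nearest]
    return out
```

For the (1, 320, 1216) tensor from the question, one would call it on `t[0]`. With about 25% of the roughly 389k entries known, each chunk of 256 needs a 256 × ~97k distance matrix (about 100 MB in float32), so this is more useful as a correctness reference than as an online-training solution.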

Thanks, this is much cleaner than the initial solution I had in mind.