Efficient kernel density estimation

i have an rgb image.
is there an efficient way to compute a kde over the colors?

assume that each channel has a maximum of 256 values, and that we discretize them into 128 bins per channel.
this forms a 3d grid over the color space: 128**3 points, each with 3 coordinates.

i made a naive implementation but it is impractical in terms of memory.

ndim = 3  # rgb
nbin = 128

def _get_color_space() -> torch.Tensor:
    # grid of kernel centers: the cartesian product of the per-channel bins
    x = torch.linspace(start=0., end=256., steps=nbin, dtype=torch.float32)
    tensors = [x for _ in range(ndim)]
    return torch.cartesian_prod(*tensors)  # nbin**ndim, ndim

color_space = _get_color_space()
dim, h, w = img.shape  # dim = ndim = 3
x = img.permute(1, 2, 0).reshape(h * w, dim)  # one rgb vector per pixel

# op1
ki = (x.unsqueeze(1) - color_space)**2  # h*w, nbin**ndim, ndim  //MEMORY ISSUE.
# op2
ki = ki.sum(dim=-1) / const2  # h*w, nbin**ndim   //MEMORY ISSUE.
# op3
ki = const1 * torch.exp(-ki)  # h*w, nbin**ndim   //MEMORY ISSUE.
# op4
prob = ki.mean(dim=0)  # nbin**ndim

the issue is that we store in memory the difference between ALL pixels and ALL kernels.
it is unnecessary.
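to put a number on it (assuming, say, a 256x256 image and float32):

```python
h = w = 256          # assumed image size
nbin, ndim = 128, 3
bytes_per_float = 4  # float32
# op1 materializes an (h*w, nbin**ndim, ndim) tensor of differences
op1_bytes = h * w * nbin**ndim * ndim * bytes_per_float
print(op1_bytes / 2**40)  # ~1.5 TiB for a single small image
```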

another approach is to LOOP over kernels to avoid storing all results in memory. something like this:

prob = torch.zeros(nbin**ndim)
for kernel_i in range(color_space.shape[0]):
    prob[kernel_i] = op_1_2_3_4(x, kernel_i)  # scalar
# where op_1_2_3_4(x, kernel_i) is simply
# ~ exp(-((x - v)**2).sum(dim=1) / const2).mean()
# where v = color_space[kernel_i] is a vector of 3 components.

the issue with this is the loop over a large number of kernels: 128**3 ≈ 2 million iterations.
i wonder if there is an efficient way to do it.
i can process the kernels in minibatches (in addition to looping over each image in the minibatch…), but it is still expensive.
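to make the chunking idea concrete, here is a sketch of the kernel-minibatch version (hypothetical names `kde_chunked`, `chunk_size`; gaussian kernel with bandwidth `bandwidth`, normalization constant omitted as in the code above): the intermediate tensor is only (h*w, chunk_size) instead of (h*w, nbin**ndim).

```python
import torch

def kde_chunked(x: torch.Tensor, centers: torch.Tensor,
                bandwidth: float = 10., chunk_size: int = 4096) -> torch.Tensor:
    """KDE of pixels x (n, 3) evaluated at centers (m, 3),
    processing the centers in chunks to bound peak memory."""
    probs = []
    for chunk in centers.split(chunk_size):
        # differences only exist for this chunk: (n, chunk, 3)
        d2 = (x.unsqueeze(1) - chunk).pow(2).sum(dim=-1)
        probs.append(torch.exp(-d2 / (2 * bandwidth**2)).mean(dim=0))
    return torch.cat(probs)  # (m,)
```

this is a compromise between the fully vectorized version and the per-kernel python loop: `chunk_size` trades memory for speed, but the total work is unchanged.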
this is similar to a convolution: each kernel center v acts like a filter, and exp(-((img - v)**2).sum(dim=0) / const2) is its 'feature map'.
but we do not need to store these feature maps: each one should be reduced to its mean right away and stored as a single scalar.
no feature map ever needs to live in memory (only instantaneously, before it is averaged).

is it possible to design an operation, similar to a 1d convolution, that reduces each feature map to its mean right away? i.e. exp(-((img - v)**2).sum(dim=0) / const2).mean()
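one possible shortcut (an observation of mine, valid only because the centers come from cartesian_prod and the kernel is an isotropic gaussian): the gaussian factorizes per channel, exp(-(dr² + dg² + db²)/c) = exp(-dr²/c) · exp(-dg²/c) · exp(-db²/c), so you only ever need three (chunk, nbin) per-channel kernel matrices and an einsum to combine them, accumulating directly into the (nbin, nbin, nbin) result. a sketch (hypothetical names `kde_separable`, `chunk_size`; kernel normalization again omitted):

```python
import torch

def kde_separable(x: torch.Tensor, nbin: int = 128,
                  bandwidth: float = 10., chunk_size: int = 1024) -> torch.Tensor:
    """KDE over the cartesian grid, exploiting per-channel separability
    of the isotropic gaussian. x: (n, 3) pixel colors in [0, 256]."""
    grid = torch.linspace(0., 256., steps=nbin)  # per-channel centers
    out = torch.zeros(nbin, nbin, nbin)
    for xc in x.split(chunk_size):
        # per-channel kernel matrices, each (chunk, nbin)
        k = torch.exp(-(xc.unsqueeze(-1) - grid).pow(2) / (2 * bandwidth**2))
        # outer product over channels, summed over the pixels in the chunk
        out += torch.einsum('na,nb,nc->abc', k[:, 0], k[:, 1], k[:, 2])
    return (out / x.shape[0]).flatten()  # (nbin**3,)
```

the flattened output matches the ordering of torch.cartesian_prod, and the largest intermediate is (chunk, nbin, nbin) inside the einsum rather than (h*w, nbin**3).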

any ideas?

in tensorflow, they have a distributions library that can express a kde: https://arxiv.org/pdf/1711.10604.pdf