Calculation and optimization of a multivariate normal for an RGB image

Hi,

Given an RGB image (shape 1, 3, H, W), I want to write an optimization procedure that finds the parameters of the pixels' multivariate normal distribution, i.e. mu (shape 1, 3) and sigma (shape 3, 3). In addition, I want to see the probability/density values as an image (shape 1, 1, H, W).

For an err (image - mean) and sigma (3x3 matrix) I perform the following:

def get_probs(err, sigma):
    ch, h, w = err.size()[1:]
    err_flat = torch.reshape(err, (h * w, ch))

    inv_sigma = torch.linalg.inv(sigma)

    # normalizer: ((2 * pi)^ch * det(sigma))^(-1/2)
    c1 = 1. / (torch.sqrt((2 * torch.pi) ** ch * torch.linalg.det(sigma) + 1e-8))

    mul = torch.matmul(err_flat, inv_sigma)
    mul = torch.einsum('ij,ij->i', mul, err_flat)

    res = c1 * (torch.exp(-0.5 * mul)).view(1, 1, h, w)

    return res

sigma and err are optimized outside of this code, with the loss below:

loss = (w * err).mean()

w is some function of res from above.

The output of this function is a weird set of duplicates of the image (see below).
But if I run this function outside of the optimization loop, I get what I expect: density values that match the input "error image".

I also tried using torch distributions:

res = torch.exp(MultivariateNormal(torch.zeros_like(err_flat), sigma.repeat(h*w, 1, 1)).log_prob(err_flat))

But the result is the same…

What am I doing wrong here?

You are interleaving the data in your reshape operation and would need to permute the axes before flattening them.

What is the correct permute order?
Is it needed only before the flatten, or also before the final reshape back to an image?

You would always need to permute the tensor if you want to move dimensions, since reshape or view on non-neighboring dimensions will interleave the tensor as seen here:

n, c, h, w = 1, 2, 3, 3

err = torch.arange(n*c*h*w).view(n, c, h, w)
print(err)
# tensor([[[[ 0,  1,  2],
#           [ 3,  4,  5],
#           [ 6,  7,  8]],

#          [[ 9, 10, 11],
#           [12, 13, 14],
#           [15, 16, 17]]]])

# wrong since it's interleaving the data
err_flat = torch.reshape(err, (n, h*w, c))
print(err_flat)
# tensor([[[ 0,  1],
#          [ 2,  3],
#          [ 4,  5],
#          [ 6,  7],
#          [ 8,  9],
#          [10, 11],
#          [12, 13],
#          [14, 15],
#          [16, 17]]])

# right as the channel dimension is permuted first
y = err.permute(0, 2, 3, 1)
print(y.shape)
# torch.Size([1, 3, 3, 2])

y = y.view(n, h*w, c)
print(y)
# tensor([[[ 0,  9],
#          [ 1, 10],
#          [ 2, 11],
#          [ 3, 12],
#          [ 4, 13],
#          [ 5, 14],
#          [ 6, 15],
#          [ 7, 16],
#          [ 8, 17]]])
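As a side note, flatten + transpose is an equivalent spelling of the same fix, in case you find it more readable:

```python
import torch

n, c, h, w = 1, 2, 3, 3
err = torch.arange(n * c * h * w).view(n, c, h, w)

# permute channels last, then flatten the spatial dims
a = err.permute(0, 2, 3, 1).contiguous().view(n, h * w, c)
# flatten spatial dims first, then swap: (n, c, h*w) -> (n, h*w, c)
b = err.flatten(2).transpose(1, 2)

print(torch.equal(a, b))  # True
```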

Thanks!

The proper flatten procedure is:

bs, ch, h, w = err.size()
err_flat = err.permute(0, 2, 3, 1).contiguous().view(bs, h * w, ch)
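Putting it together, here is a corrected version of the full function as I understand it (the final view back to (bs, 1, h, w) needs no extra permute, since the permuted flatten keeps pixels in row-major (h, w) order), with a sanity check against torch.distributions.MultivariateNormal:

```python
import torch
from torch.distributions import MultivariateNormal

def get_probs(err, sigma):
    # err: (bs, ch, h, w) residual image; sigma: (ch, ch) covariance
    bs, ch, h, w = err.size()
    # channels last, then flatten: one row per pixel
    err_flat = err.permute(0, 2, 3, 1).contiguous().view(bs * h * w, ch)

    inv_sigma = torch.linalg.inv(sigma)
    # normalizer of the multivariate normal: ((2*pi)^ch * det(sigma))^(-1/2)
    c1 = 1.0 / torch.sqrt((2 * torch.pi) ** ch * torch.linalg.det(sigma) + 1e-8)

    # per-pixel Mahalanobis term err^T sigma^{-1} err
    mul = torch.einsum('ij,jk,ik->i', err_flat, inv_sigma, err_flat)

    # pixels were flattened in row-major (h, w) order, so this view is correct
    return c1 * torch.exp(-0.5 * mul).view(bs, 1, h, w)

# sanity check against torch.distributions
bs, ch, h, w = 1, 3, 4, 5
err = torch.randn(bs, ch, h, w)
A = torch.randn(ch, ch)
sigma = A @ A.T + ch * torch.eye(ch)  # make sigma positive definite

res = get_probs(err, sigma)
ref = torch.exp(
    MultivariateNormal(torch.zeros(ch), sigma)
    .log_prob(err.permute(0, 2, 3, 1).reshape(-1, ch))
).view(bs, 1, h, w)
print(torch.allclose(res, ref, atol=1e-6))  # True
```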