# Calculation and optimization of multivariate norm for RGB image

Hi,

Given an RGB image (shape `(1, 3, H, W)`) I want to write an optimization procedure to find the parameters of the pixels' multivariate normal distribution, i.e. mu (shape `(1, 3)`) and sigma (shape `(3, 3)`). In addition, I want to see the probability/density values as an image (shape `(1, 1, H, W)`).

For an error `err` (image minus mean) and `sigma` (a 3x3 covariance matrix) I perform the following:

```python
def get_probs(err, sigma):
    ch, h, w = err.size()[1:]
    err_flat = torch.reshape(err, (h * w, ch))

    inv_sigma = torch.linalg.inv(sigma)

    c1 = 1. / (torch.sqrt((2 * torch.pi) ** (1. / ch) * torch.linalg.det(sigma) + 1e-8))

    mul = torch.matmul(err_flat, inv_sigma)
    mul = torch.einsum('ij,ij->i', mul, err_flat)

    res = c1 * (torch.exp(-0.5 * mul)).view(1, 1, h, w)

    return res
```

sigma and err are optimized outside of this code, with the loss below:

```python
loss = (w * err).mean()
```

`w` is some function of `res` from above.

The output of this function is a weird set of duplicates of the image (see below).
But if I run this function outside of the optimization loop, I get what I expect: density values that match the input "error image".

I also tried to use `torch.distributions`:

```python
res = torch.exp(MultivariateNormal(torch.zeros_like(err_flat), sigma.repeat(h*w, 1, 1)).log_prob(err_flat))
```
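As an editorial aside (not from the original thread): `MultivariateNormal` broadcasts a single mean/covariance over a batch of samples, so the per-pixel `repeat` of `sigma` shouldn't be necessary. A quick sketch with made-up shapes:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
h, w, ch = 4, 5, 3
err_flat = torch.randn(h * w, ch)
A = torch.randn(ch, ch)
sigma = A @ A.T + 0.5 * torch.eye(ch)  # a random SPD covariance for the demo

# a single (ch,)-mean / (ch, ch)-covariance distribution broadcasts over the batch
dist = MultivariateNormal(torch.zeros(ch), covariance_matrix=sigma)
res = torch.exp(dist.log_prob(err_flat))  # shape (h * w,)

# same result as repeating sigma once per pixel
res_rep = torch.exp(
    MultivariateNormal(torch.zeros_like(err_flat), sigma.repeat(h * w, 1, 1)).log_prob(err_flat)
)
print(torch.allclose(res, res_rep))  # True
```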

But the result is the same…

What am I doing wrong here?

You are interleaving the data in your `reshape` operation and would need to `permute` the axes before flattening them.

What is the correct permute order?
Is it only before flatten or also before final reshape?

You would always need to `permute` the tensor if you want to move dimensions, since `reshape` or `view` on non-neighboring dimensions will interleave the tensor as seen here:

```python
n, c, h, w = 1, 2, 3, 3

err = torch.arange(n*c*h*w).view(n, c, h, w)
print(err)
# tensor([[[[ 0,  1,  2],
#           [ 3,  4,  5],
#           [ 6,  7,  8]],

#          [[ 9, 10, 11],
#           [12, 13, 14],
#           [15, 16, 17]]]])

# wrong, since it's interleaving the data
err_flat = torch.reshape(err, (n, h*w, c))
print(err_flat)
# tensor([[[ 0,  1],
#          [ 2,  3],
#          [ 4,  5],
#          [ 6,  7],
#          [ 8,  9],
#          [10, 11],
#          [12, 13],
#          [14, 15],
#          [16, 17]]])

# right, as the channel dimension is permuted first
y = err.permute(0, 2, 3, 1)
print(y.shape)
# torch.Size([1, 3, 3, 2])

# permute makes the tensor non-contiguous, so call .contiguous() before .view
y = y.contiguous().view(n, h*w, c)
print(y)
# tensor([[[ 0,  9],
#          [ 1, 10],
#          [ 2, 11],
#          [ 3, 12],
#          [ 4, 13],
#          [ 5, 14],
#          [ 6, 15],
#          [ 7, 16],
#          [ 8, 17]]])
```
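(An editorial aside, not from the thread:) the same channel-last flattening can also be written with `reshape` plus `transpose`, which avoids the explicit `permute` and contiguity concerns. A quick equivalence check:

```python
import torch

n, c, h, w = 1, 2, 3, 3
err = torch.arange(n * c * h * w).view(n, c, h, w)

# permute to channel-last, then flatten the spatial dims
a = err.permute(0, 2, 3, 1).reshape(n, h * w, c)
# equivalent: flatten the spatial dims first, then swap the last two axes
b = err.reshape(n, c, h * w).transpose(1, 2)

print(torch.equal(a, b))  # True
```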

Thanks!

The proper flattening procedure is:

```python
bs, ch, h, w = err.size()
err_flat = err.permute(0, 2, 3, 1).contiguous().view(bs, h * w, ch)
```
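Not part of the original thread, but putting the fix together: a sketch of a corrected `get_probs` (batch size 1, as in the question). Note that the normalization constant in the original question also looks off: for a `ch`-dimensional Gaussian it should be `1 / sqrt((2*pi) ** ch * det(sigma))`, i.e. the exponent is `ch`, not `1. / ch`. The version below is checked against `torch.distributions.MultivariateNormal`:

```python
import torch
from torch.distributions import MultivariateNormal

def get_probs(err, sigma):
    ch, h, w = err.size()[1:]
    # channel-last, then flatten pixels: each row is one (R, G, B) error vector
    err_flat = err.permute(0, 2, 3, 1).contiguous().view(h * w, ch)

    inv_sigma = torch.linalg.inv(sigma)
    # 1 / sqrt((2*pi)^ch * det(sigma)); note the exponent is ch, not 1/ch
    c1 = 1. / torch.sqrt((2 * torch.pi) ** ch * torch.linalg.det(sigma) + 1e-8)

    # per-pixel Mahalanobis term err^T Sigma^{-1} err
    mul = torch.einsum('ij,jk,ik->i', err_flat, inv_sigma, err_flat)
    return (c1 * torch.exp(-0.5 * mul)).view(1, 1, h, w)

# quick check against torch.distributions
torch.manual_seed(0)
err = torch.randn(1, 3, 4, 5)
A = torch.randn(3, 3)
sigma = A @ A.T + 0.5 * torch.eye(3)  # a random SPD covariance for the demo

probs = get_probs(err, sigma)
ref = torch.exp(
    MultivariateNormal(torch.zeros(3), covariance_matrix=sigma)
    .log_prob(err.permute(0, 2, 3, 1).reshape(-1, 3))
).view(1, 1, 4, 5)
print(torch.allclose(probs, ref, atol=1e-5))  # True
```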