Pytorch implementation of MEFSSIM metric

Hi.
I’m trying to implement the MEFSSIM metric proposed in 《Perceptual Quality Assessment for
Multi-Exposure Image Fusion》.
However, I find it hard to implement it in an efficient and highly-vectorized way. Does anyone have a better solution? (There are several pytorch implementations of mefssim available on github, but they are all somewhat inconsistent with the matlab implementation provided by the author)
The matlab implementation is as follow:

``````function [Q, qMap] = mef_ssim(imgSeq, fI,  K, window)

if (nargin < 2 || nargin > 4) % restrict # of params to be in range [2, 4]
Q = -Inf;
qMap = Inf;
return;
end

if (~exist('K', 'var'))
K = 0.03;
end

if (~exist('window', 'var'))
window = fspecial('gaussian', 11, 1.5); % fspecial('gaussian',hsize,sigma) returns a rotationally symmetric Gaussian lowpass filter of size hsize with standard deviation sigma.
end

imgSeq = double(imgSeq);
fI = double(fI);
[s1, s2, s3] = size(imgSeq); % s1, s2 is height and width of imgseq, s3 is number of imgs
wSize = size(window,1);
sWindow = ones(wSize) / wSize^2; % square window used to calculate the distance, ones(size) return an all-one matirx of size (size,size)
bd = floor(wSize/2);
mu = zeros(s1-2*bd, s2-2*bd, s3);
ed = zeros(s1-2*bd, s2-2*bd, s3);
for i = 1:s3 % for all input imgs in input sequence, apply mean filter
img = squeeze(imgSeq(:,:,i));
% mu is mean value of the patch u_xk
mu(:,:,i) = filter2(sWindow, img, 'valid'); % valid means no zero-padding, thus the size changes
muSq = mu(:,:,i) .* mu(:,:,i);
sigmaSq = filter2(sWindow, img.*img, 'valid') - muSq;
ed(:,:,i) =  sqrt( max( wSize^2 * sigmaSq, 0 ) ) + 0.001; % add a small constant to avoid instability
end

R = zeros(s1-2*bd,s2-2*bd); % consistency map which could be used as an output if necessary
for i = bd+1:s1-bd
for j = bd+1:s2-bd
vecs = reshape(imgSeq(i-bd:i+bd,j-bd:j+bd,:),[wSize*wSize, s3]);
denominator = 0;
for k = 1:s3
denominator = denominator + norm(vecs(:,k) - mu(i-bd,j-bd,k));
end
numerator = norm(sum(vecs,2) - mean(sum(vecs,2)));
R(i-bd,j-bd) = (numerator + eps) / (denominator + eps); % eq (6)
end
end

R(R > 1) = 1 - eps; % get rid of numerical instability
R(R < 0) = 0 + eps;

p = tan(pi/2 * R); % eq(7)
p( p >  10 ) = 10; % to avoid blow up (large number such as 10 is equivalent to taking maximum)
p = repmat(p,[1,1,s3]);

wMap = (ed / wSize).^p + eps; % to avoid blowing up eq(5)
normalizer = sum(wMap,3);
wMap = wMap ./ repmat(normalizer,[1,1,s3]);
maxEd = max(ed,[],3); % eq(3)

C = (K * 255)^2;
qMap = zeros(s1-2*bd, s2-2*bd);
for i = bd+1:s1-bd
for j = bd+1:s2-bd
blocks = imgSeq(i-bd:i+bd,j-bd:j+bd,:);
rBlock = zeros(wSize,wSize);
for k = 1 : s3
rBlock = rBlock  + wMap(i-bd,j-bd,k) * ( blocks(:,:,k) - mu(i-bd,j-bd,k) ) / ed(i-bd,j-bd,k); % eq(4)
end
if norm(rBlock(:)) > 0
rBlock = rBlock / norm(rBlock(:)) * maxEd(i-bd,j-bd);
end
fBlock = fI(i-bd:i+bd,j-bd:j+bd);
rVec = rBlock(:);
fVec = fBlock(:);
mu1 = sum( window(:) .* rVec );
mu2 = sum( window(:) .* fVec );
sigma1Sq = sum( window(:) .* (rVec - mu1).^2 );
sigma2Sq = sum( window(:) .* (fVec - mu2).^2 );
sigma12 = sum(  window(:) .* (rVec - mu1) .* (fVec - mu2)  );
qMap(i-bd,j-bd) = ( 2 * sigma12 + C ) ./ ( sigma1Sq + sigma2Sq + C );
end
end

Q = mean2(qMap); % Average or mean of matrix elements
``````

My pytorch implementation:

``````import torch
import torch.nn.functional as F
import numpy as np
import math

def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()

def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window

def mef_ssim(imgSeq, fI, K=0.03):
fI = fI.squeeze()  # (batch_size, h, w)
eps = 2.2204e-16
window = create_window(11, 1).cuda()

(batch_size, imgSeq_channel, s1, s2) = imgSeq[0].shape
s3 = len(imgSeq)
imgSeq_stack = torch.stack(imgSeq).squeeze().unsqueeze(1).cuda().float()  # (img_num, batch_size, h, w)

wSize = window.shape[2]
sWindow = torch.ones(window.size()).cuda().float() / wSize ** 2
bd = wSize // 2
mu = []
ed = []
for i in range(s3):
img = imgSeq[i].cuda().float()  # (batch_size, channel, h, w)
mu.append(F.conv2d(img, sWindow, padding = wSize//2, groups = imgSeq_channel))
muSq = mu[i] * mu[i]
sigmaSq = F.conv2d(img * img, sWindow, padding = wSize//2, groups = imgSeq_channel) - muSq
ed.append(torch.sqrt(wSize ** 2 * sigmaSq))
ed[i][ed[i]<0.001] = 0.001
ed = torch.stack(ed).squeeze().unsqueeze(1)  # (img_num, batch_size, h, w)
mu = torch.stack(mu).squeeze().unsqueeze(1)  # (img_num, batch_size, h, w)
x_hat = (imgSeq_stack - mu).float()

SumFilter = torch.ones(window.size()).cuda().float()
numerator = torch.zeros((x_hat.size())).cuda()
denominator = torch.zeros((x_hat.size())).cuda()
for i in range(s3):
denominator = denominator + torch.sqrt(F.conv2d(x_hat**2, SumFilter, padding = wSize//2, groups = imgSeq_channel))
numerator = torch.sqrt(F.conv2d((torch.sum(x_hat,0)**2).unsqueeze(0), SumFilter, padding = wSize//2, groups = imgSeq_channel))
R = (numerator + eps) / (denominator+eps)

# R = torch.zeros((batch_size, s1 - 2 * bd, s2 - 2 * bd)).cuda()  # (batch_size, h, w)
# for batch in range(batch_size):
#     for i in range(bd, s1 - bd):
#         for j in range(bd, s2 - bd):
#             vecs = torch.reshape(imgSeq_stack[:, batch, i - bd: i + bd + 1, j - bd : j + bd + 1], (wSize * wSize, s3))
#             denominator = 0
#             for k in range(s3):
#                 denominator = denominator + torch.norm(vecs[:, k] - mu[k, batch, i - bd - 1, j - bd - 1])
#             numerator = torch.norm(torch.sum(vecs, 1) - torch.mean(torch.sum(vecs, 1)))
#             R[batch, i - bd - 1, j - bd - 1] = (numerator + eps) / (denominator + eps)

R[R > 1] = 1 - eps
R[R < 0] = 0 + eps

p = torch.tan( 1.57 * R)
p[p > 10] = 10
p_rep = []
for i in range(s3):
p_rep.append(p)
p = torch.stack(p_rep)  # (img_num, batch_size, h, w)

wMap = (ed / wSize) ** p + eps  # (img_num, batch_size, h, w)
normalizer = torch.sum(wMap, 0)
normalizer_rep = []
for i in range(s3):
normalizer_rep.append(normalizer)
normalizer_rep = torch.stack(normalizer_rep)
wMap = wMap / normalizer_rep
maxEd, maxEd_idx = torch.max(ed, dim=0)

C = (K * 1) ** 2
qMap = torch.zeros((x_hat.size())).cuda()

# UnitFliter = torch.zeros(window.size()).cuda().float()
# UnitFliter[:,:,5,5] = 1
# rBlock = torch.sum(F.conv2d(wMap, UnitFliter, padding = wSize//2, groups = imgSeq_channel) *  F.conv2d(x_hat, UnitFliter, padding = wSize//2, groups = imgSeq_channel) /  F.conv2d(ed, UnitFliter, padding = wSize//2, groups = imgSeq_channel), dim=0)
# if ()
for batch in range(batch_size):
for i in range(bd, s1 - bd):
for j in range(bd, s2 - bd):
blocks = imgSeq_stack[:, batch, i - bd: i + bd, j-bd : j + bd]
rBlock = torch.zeros((wSize, wSize))
for k in range(s3):
rBlock = rBlock + wMap[k, batch, i - bd, j - bd] * (
blocks[k, :, :, :] - mu[k, batch, i - bd, j - bd]) / ed[k, batch, i - bd, j - bd]

if (torch.norm(rBlock) > 0):
rBlock = rBlock / torch.norm(rBlock) * maxEd[batch, i - bd, j - bd]

fBlock = fI[batch, i - bd:i + bd, j - bd:j + bd]
rVec = rBlock.view(-1, 1)
fVec = fBlock.view(-1, 1)
windowVec = window.view(-1, 1)
mu1 = torch.sum(windowVec * rVec)
mu2 = torch.sum(windowVec * fVec)
sigma1Sq = torch.sum(windowVec * (rVec - mu1) ** 2)
sigma2Sq = torch.sum(windowVec * (fVec - mu2) ** 2)
sigma12 = torch.sum(windowVec * (rVec - mu1) * (fVec - mu2))
qMap[batch, i - bd, j - bd] = (2 * sigma12 + C) / (sigma1Sq + sigma2Sq + C)

Q = torch.mean(qMap)
``````

we have a similar implementation in kornia