Pytorch implementation of MEFSSIM metric

I’m trying to implement the MEFSSIM metric proposed in 《Perceptual Quality Assessment for
Multi-Exposure Image Fusion》.
However, I find it hard to implement it in an efficient and highly-vectorized way. Does anyone have a better solution? (There are several pytorch implementations of mefssim available on github, but they are all somewhat inconsistent with the matlab implementation provided by the author)
The matlab implementation is as follow:

function [Q, qMap] = mef_ssim(imgSeq, fI,  K, window) 

if (nargin < 2 || nargin > 4) % restrict # of params to be in range [2, 4]
   Q = -Inf;
   qMap = Inf;

if (~exist('K', 'var'))
   K = 0.03;

if (~exist('window', 'var'))
   window = fspecial('gaussian', 11, 1.5); % fspecial('gaussian',hsize,sigma) returns a rotationally symmetric Gaussian lowpass filter of size hsize with standard deviation sigma.

imgSeq = double(imgSeq);
fI = double(fI);
[s1, s2, s3] = size(imgSeq); % s1, s2 is height and width of imgseq, s3 is number of imgs
wSize = size(window,1);
sWindow = ones(wSize) / wSize^2; % square window used to calculate the distance, ones(size) return an all-one matirx of size (size,size)
bd = floor(wSize/2);
mu = zeros(s1-2*bd, s2-2*bd, s3);
ed = zeros(s1-2*bd, s2-2*bd, s3);
for i = 1:s3 % for all input imgs in input sequence, apply mean filter
    img = squeeze(imgSeq(:,:,i));
    % mu is mean value of the patch u_xk
    mu(:,:,i) = filter2(sWindow, img, 'valid'); % valid means no zero-padding, thus the size changes
    muSq = mu(:,:,i) .* mu(:,:,i);
    sigmaSq = filter2(sWindow, img.*img, 'valid') - muSq;
    ed(:,:,i) =  sqrt( max( wSize^2 * sigmaSq, 0 ) ) + 0.001; % add a small constant to avoid instability

R = zeros(s1-2*bd,s2-2*bd); % consistency map which could be used as an output if necessary
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        vecs = reshape(imgSeq(i-bd:i+bd,j-bd:j+bd,:),[wSize*wSize, s3]);
        denominator = 0;
        for k = 1:s3
            denominator = denominator + norm(vecs(:,k) - mu(i-bd,j-bd,k));
        numerator = norm(sum(vecs,2) - mean(sum(vecs,2)));
        R(i-bd,j-bd) = (numerator + eps) / (denominator + eps); % eq (6)

R(R > 1) = 1 - eps; % get rid of numerical instability
R(R < 0) = 0 + eps;

p = tan(pi/2 * R); % eq(7)
p( p >  10 ) = 10; % to avoid blow up (large number such as 10 is equivalent to taking maximum)
p = repmat(p,[1,1,s3]);

wMap = (ed / wSize).^p + eps; % to avoid blowing up eq(5)
normalizer = sum(wMap,3);
wMap = wMap ./ repmat(normalizer,[1,1,s3]);
maxEd = max(ed,[],3); % eq(3)

C = (K * 255)^2;
qMap = zeros(s1-2*bd, s2-2*bd); 
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        blocks = imgSeq(i-bd:i+bd,j-bd:j+bd,:);
        rBlock = zeros(wSize,wSize);
        for k = 1 : s3
            rBlock = rBlock  + wMap(i-bd,j-bd,k) * ( blocks(:,:,k) - mu(i-bd,j-bd,k) ) / ed(i-bd,j-bd,k); % eq(4)
        if norm(rBlock(:)) > 0
            rBlock = rBlock / norm(rBlock(:)) * maxEd(i-bd,j-bd);
        fBlock = fI(i-bd:i+bd,j-bd:j+bd);
        rVec = rBlock(:);
        fVec = fBlock(:);
        mu1 = sum( window(:) .* rVec );
        mu2 = sum( window(:) .* fVec );
        sigma1Sq = sum( window(:) .* (rVec - mu1).^2 );
        sigma2Sq = sum( window(:) .* (fVec - mu2).^2 );
        sigma12 = sum(  window(:) .* (rVec - mu1) .* (fVec - mu2)  );
        qMap(i-bd,j-bd) = ( 2 * sigma12 + C ) ./ ( sigma1Sq + sigma2Sq + C ); 

Q = mean2(qMap); % Average or mean of matrix elements

My pytorch implementation:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math

def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window =
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def mef_ssim(imgSeq, fI, K=0.03):
    fI = fI.squeeze()  # (batch_size, h, w)
    eps = 2.2204e-16
    window = create_window(11, 1).cuda()

    (batch_size, imgSeq_channel, s1, s2) = imgSeq[0].shape
    s3 = len(imgSeq)
    imgSeq_stack = torch.stack(imgSeq).squeeze().unsqueeze(1).cuda().float()  # (img_num, batch_size, h, w)

    wSize = window.shape[2]
    sWindow = torch.ones(window.size()).cuda().float() / wSize ** 2
    bd = wSize // 2
    mu = []
    ed = []
    for i in range(s3):
        img = imgSeq[i].cuda().float()  # (batch_size, channel, h, w)
        mu.append(F.conv2d(img, sWindow, padding = wSize//2, groups = imgSeq_channel))
        muSq = mu[i] * mu[i]
        sigmaSq = F.conv2d(img * img, sWindow, padding = wSize//2, groups = imgSeq_channel) - muSq
        ed.append(torch.sqrt(wSize ** 2 * sigmaSq))
        ed[i][ed[i]<0.001] = 0.001
    ed = torch.stack(ed).squeeze().unsqueeze(1)  # (img_num, batch_size, h, w)
    mu = torch.stack(mu).squeeze().unsqueeze(1)  # (img_num, batch_size, h, w)
    x_hat = (imgSeq_stack - mu).float()

    SumFilter = torch.ones(window.size()).cuda().float()
    numerator = torch.zeros((x_hat.size())).cuda()
    denominator = torch.zeros((x_hat.size())).cuda()
    for i in range(s3):
        denominator = denominator + torch.sqrt(F.conv2d(x_hat**2, SumFilter, padding = wSize//2, groups = imgSeq_channel))
    numerator = torch.sqrt(F.conv2d((torch.sum(x_hat,0)**2).unsqueeze(0), SumFilter, padding = wSize//2, groups = imgSeq_channel))
    R = (numerator + eps) / (denominator+eps)

    # R = torch.zeros((batch_size, s1 - 2 * bd, s2 - 2 * bd)).cuda()  # (batch_size, h, w)
    # for batch in range(batch_size):
    #     for i in range(bd, s1 - bd):
    #         for j in range(bd, s2 - bd):
    #             vecs = torch.reshape(imgSeq_stack[:, batch, i - bd: i + bd + 1, j - bd : j + bd + 1], (wSize * wSize, s3))
    #             denominator = 0
    #             for k in range(s3):
    #                 denominator = denominator + torch.norm(vecs[:, k] - mu[k, batch, i - bd - 1, j - bd - 1])
    #             numerator = torch.norm(torch.sum(vecs, 1) - torch.mean(torch.sum(vecs, 1)))
    #             R[batch, i - bd - 1, j - bd - 1] = (numerator + eps) / (denominator + eps)

    R[R > 1] = 1 - eps
    R[R < 0] = 0 + eps

    p = torch.tan( 1.57 * R)
    p[p > 10] = 10
    p_rep = []
    for i in range(s3):
    p = torch.stack(p_rep)  # (img_num, batch_size, h, w)

    wMap = (ed / wSize) ** p + eps  # (img_num, batch_size, h, w)
    normalizer = torch.sum(wMap, 0)
    normalizer_rep = []
    for i in range(s3):
    normalizer_rep = torch.stack(normalizer_rep)
    wMap = wMap / normalizer_rep
    maxEd, maxEd_idx = torch.max(ed, dim=0)

    C = (K * 1) ** 2
    qMap = torch.zeros((x_hat.size())).cuda()

    # UnitFliter = torch.zeros(window.size()).cuda().float()
    # UnitFliter[:,:,5,5] = 1
    # rBlock = torch.sum(F.conv2d(wMap, UnitFliter, padding = wSize//2, groups = imgSeq_channel) *  F.conv2d(x_hat, UnitFliter, padding = wSize//2, groups = imgSeq_channel) /  F.conv2d(ed, UnitFliter, padding = wSize//2, groups = imgSeq_channel), dim=0)
    # if ()
    for batch in range(batch_size):
        for i in range(bd, s1 - bd):
            for j in range(bd, s2 - bd):
                blocks = imgSeq_stack[:, batch, i - bd: i + bd, j-bd : j + bd]
                rBlock = torch.zeros((wSize, wSize))
                for k in range(s3):
                    rBlock = rBlock + wMap[k, batch, i - bd, j - bd] * (
                                blocks[k, :, :, :] - mu[k, batch, i - bd, j - bd]) / ed[k, batch, i - bd, j - bd]

                if (torch.norm(rBlock) > 0):
                    rBlock = rBlock / torch.norm(rBlock) * maxEd[batch, i - bd, j - bd]

                fBlock = fI[batch, i - bd:i + bd, j - bd:j + bd]
                rVec = rBlock.view(-1, 1)
                fVec = fBlock.view(-1, 1)
                windowVec = window.view(-1, 1)
                mu1 = torch.sum(windowVec * rVec)
                mu2 = torch.sum(windowVec * fVec)
                sigma1Sq = torch.sum(windowVec * (rVec - mu1) ** 2)
                sigma2Sq = torch.sum(windowVec * (fVec - mu2) ** 2)
                sigma12 = torch.sum(windowVec * (rVec - mu1) * (fVec - mu2))
                qMap[batch, i - bd, j - bd] = (2 * sigma12 + C) / (sigma1Sq + sigma2Sq + C)

    Q = torch.mean(qMap)

we have a similar implementation in kornia

Thanks for the reply.
But SSIM is actually different from MEFSSIM :sweat_smile: