PyTorch implementation of MEFSSIM metric

Hi.
I'm trying to implement the MEFSSIM metric proposed in "Perceptual Quality Assessment for Multi-Exposure Image Fusion".
However, I find it hard to implement it in an efficient, highly vectorized way. Does anyone have a better solution? (There are several PyTorch implementations of MEFSSIM available on GitHub, but they are all somewhat inconsistent with the MATLAB implementation provided by the author.)
The MATLAB implementation is as follows:

function [Q, qMap] = mef_ssim(imgSeq, fI,  K, window) 

if (nargin < 2 || nargin > 4) % restrict # of params to be in range [2, 4]
   Q = -Inf;
   qMap = Inf;
   return;
end

if (~exist('K', 'var'))
   K = 0.03;
end

if (~exist('window', 'var'))
   window = fspecial('gaussian', 11, 1.5); % fspecial('gaussian',hsize,sigma) returns a rotationally symmetric Gaussian lowpass filter of size hsize with standard deviation sigma.
end


imgSeq = double(imgSeq);
fI = double(fI);
[s1, s2, s3] = size(imgSeq); % s1, s2 are the height and width of imgSeq, s3 is the number of images
wSize = size(window,1);
sWindow = ones(wSize) / wSize^2; % square averaging window used to calculate the local mean and signal strength; ones(n) returns an all-ones matrix of size (n, n)
bd = floor(wSize/2);
mu = zeros(s1-2*bd, s2-2*bd, s3);
ed = zeros(s1-2*bd, s2-2*bd, s3);
for i = 1:s3 % for all input imgs in input sequence, apply mean filter
    img = squeeze(imgSeq(:,:,i));
    % mu is mean value of the patch u_xk
    mu(:,:,i) = filter2(sWindow, img, 'valid'); % valid means no zero-padding, thus the size changes
    muSq = mu(:,:,i) .* mu(:,:,i);
    sigmaSq = filter2(sWindow, img.*img, 'valid') - muSq;
    ed(:,:,i) =  sqrt( max( wSize^2 * sigmaSq, 0 ) ) + 0.001; % add a small constant to avoid instability
end

R = zeros(s1-2*bd,s2-2*bd); % consistency map which could be used as an output if necessary
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        vecs = reshape(imgSeq(i-bd:i+bd,j-bd:j+bd,:),[wSize*wSize, s3]);
        denominator = 0;
        for k = 1:s3
            denominator = denominator + norm(vecs(:,k) - mu(i-bd,j-bd,k));
        end
        numerator = norm(sum(vecs,2) - mean(sum(vecs,2)));
        R(i-bd,j-bd) = (numerator + eps) / (denominator + eps); % eq (6)
    end
end

R(R > 1) = 1 - eps; % get rid of numerical instability
R(R < 0) = 0 + eps;


p = tan(pi/2 * R); % eq(7)
p( p >  10 ) = 10; % to avoid blow up (large number such as 10 is equivalent to taking maximum)
p = repmat(p,[1,1,s3]);


wMap = (ed / wSize).^p + eps; % to avoid blowing up eq(5)
normalizer = sum(wMap,3);
wMap = wMap ./ repmat(normalizer,[1,1,s3]);
maxEd = max(ed,[],3); % eq(3)

C = (K * 255)^2;
qMap = zeros(s1-2*bd, s2-2*bd); 
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        blocks = imgSeq(i-bd:i+bd,j-bd:j+bd,:);
        rBlock = zeros(wSize,wSize);
        for k = 1 : s3
            rBlock = rBlock  + wMap(i-bd,j-bd,k) * ( blocks(:,:,k) - mu(i-bd,j-bd,k) ) / ed(i-bd,j-bd,k); % eq(4)
        end  
        if norm(rBlock(:)) > 0
            rBlock = rBlock / norm(rBlock(:)) * maxEd(i-bd,j-bd);
        end
        fBlock = fI(i-bd:i+bd,j-bd:j+bd);
        rVec = rBlock(:);
        fVec = fBlock(:);
        mu1 = sum( window(:) .* rVec );
        mu2 = sum( window(:) .* fVec );
        sigma1Sq = sum( window(:) .* (rVec - mu1).^2 );
        sigma2Sq = sum( window(:) .* (fVec - mu2).^2 );
        sigma12 = sum(  window(:) .* (rVec - mu1) .* (fVec - mu2)  );
        qMap(i-bd,j-bd) = ( 2 * sigma12 + C ) ./ ( sigma1Sq + sigma2Sq + C ); 
    end
end

Q = mean2(qMap); % Average or mean of matrix elements
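
For reference, this is how I read the quantities the code computes (my own notation; I write \(\hat{x}_k = x_k - \mu_k\) for the mean-removed patch of exposure \(k\) and \(c_k = \lVert \hat{x}_k \rVert\) for its strength, which is what ed stores up to the +0.001 stabilizer):

R = \frac{\left\lVert \sum_{k} \hat{x}_k \right\rVert + \epsilon}{\sum_{k} \lVert \hat{x}_k \rVert + \epsilon} \qquad \text{(eq. 6, consistency map)}

p = \tan\!\left(\frac{\pi}{2} R\right) \qquad \text{(eq. 7)}

w_k = \frac{c_k^{\,p}}{\sum_{l} c_l^{\,p}} \qquad \text{(eq. 5, wMap; the common 1/wSize factor cancels in the normalization)}

\hat{r} = \sum_{k} w_k \frac{\hat{x}_k}{c_k}, \qquad r = \frac{\hat{r}}{\lVert \hat{r} \rVert} \max_{k} c_k \qquad \text{(eqs. 3-4, desired patch, rescaled to the largest local strength)}

\mathrm{qMap} = \frac{2\sigma_{rf} + C}{\sigma_r^2 + \sigma_f^2 + C}, \qquad C = (K \cdot 255)^2 \qquad \text{(structure/contrast term of SSIM between } r \text{ and the fused patch)}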

My PyTorch implementation:

import torch
import torch.nn.functional as F
import math


def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # Variable is deprecated; a plain tensor works the same way here
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def mef_ssim(imgSeq, fI, K=0.03):
    fI = fI.squeeze(1).cuda().float()  # (batch_size, h, w); squeeze only the channel dim so the batch dim is kept
    eps = 2.2204e-16
    window = create_window(11, 1).cuda()

    (batch_size, imgSeq_channel, s1, s2) = imgSeq[0].shape
    s3 = len(imgSeq)
    imgSeq_stack = torch.stack(imgSeq).squeeze().unsqueeze(1).cuda().float()  # (img_num, 1, h, w); note: this assumes batch_size == 1 and single-channel images

    wSize = window.shape[2]
    sWindow = torch.ones(window.size()).cuda().float() / wSize ** 2
    bd = wSize // 2
    mu = []
    ed = []
    for i in range(s3):
        img = imgSeq[i].cuda().float()  # (batch_size, channel, h, w)
        # same-padding keeps the full spatial size, unlike MATLAB's 'valid' filtering,
        # so these maps are indexed at the window centre (i, j) rather than at (i - bd, j - bd)
        mu.append(F.conv2d(img, sWindow, padding = wSize//2, groups = imgSeq_channel))
        muSq = mu[i] * mu[i]
        sigmaSq = F.conv2d(img * img, sWindow, padding = wSize//2, groups = imgSeq_channel) - muSq
        # clamp small negative variances (numerical error) before the sqrt and add 0.001, as in the MATLAB code
        ed.append(torch.sqrt(torch.clamp(wSize ** 2 * sigmaSq, min=0)) + 0.001)
    ed = torch.stack(ed).squeeze().unsqueeze(1)  # (img_num, 1, h, w)
    mu = torch.stack(mu).squeeze().unsqueeze(1)  # (img_num, 1, h, w)
    x_hat = (imgSeq_stack - mu).float()

    SumFilter = torch.ones(window.size()).cuda().float()
    # eq (6): R = ||sum_k (x_k - mu_k)|| / sum_k ||x_k - mu_k||, computed per window.
    # Note: this uses each pixel's own local mean inside the window, while the MATLAB code
    # subtracts the scalar mean of the patch centred at (i, j), so the two only agree approximately.
    denominator = torch.sqrt(F.conv2d(x_hat ** 2, SumFilter, padding = wSize//2, groups = imgSeq_channel)).sum(dim=0, keepdim=True)
    numerator = torch.sqrt(F.conv2d((torch.sum(x_hat, 0) ** 2).unsqueeze(0), SumFilter, padding = wSize//2, groups = imgSeq_channel))
    R = (numerator + eps) / (denominator + eps)  # (1, 1, h, w)

    # R = torch.zeros((batch_size, s1 - 2 * bd, s2 - 2 * bd)).cuda()  # (batch_size, h, w)
    # for batch in range(batch_size):
    #     for i in range(bd, s1 - bd):
    #         for j in range(bd, s2 - bd):
    #             vecs = torch.reshape(imgSeq_stack[:, batch, i - bd: i + bd + 1, j - bd : j + bd + 1], (wSize * wSize, s3))
    #             denominator = 0
    #             for k in range(s3):
    #                 denominator = denominator + torch.norm(vecs[:, k] - mu[k, batch, i - bd - 1, j - bd - 1])
    #             numerator = torch.norm(torch.sum(vecs, 1) - torch.mean(torch.sum(vecs, 1)))
    #             R[batch, i - bd - 1, j - bd - 1] = (numerator + eps) / (denominator + eps)

    R[R > 1] = 1 - eps
    R[R < 0] = 0 + eps

    p = torch.tan(math.pi / 2 * R)  # eq (7); use pi/2 rather than the truncated 1.57 to match MATLAB
    p = torch.clamp(p, max=10)      # avoid blow-up for R close to 1 (equivalent to p[p > 10] = 10)

    wMap = (ed / wSize) ** p + eps  # eq (5); p has shape (1, 1, h, w) and broadcasts over the img_num dim
    wMap = wMap / torch.sum(wMap, 0, keepdim=True)
    maxEd, maxEd_idx = torch.max(ed, dim=0)

    C = (K * 1) ** 2  # the MATLAB code uses (K * 255)^2; K * 1 assumes inputs scaled to [0, 1]
    qMap = torch.zeros((batch_size, s1 - 2 * bd, s2 - 2 * bd)).cuda()  # cropped map, as in the MATLAB code

    # UnitFliter = torch.zeros(window.size()).cuda().float()
    # UnitFliter[:,:,5,5] = 1
    # rBlock = torch.sum(F.conv2d(wMap, UnitFliter, padding = wSize//2, groups = imgSeq_channel) *  F.conv2d(x_hat, UnitFliter, padding = wSize//2, groups = imgSeq_channel) /  F.conv2d(ed, UnitFliter, padding = wSize//2, groups = imgSeq_channel), dim=0)
    # if ()
    for batch in range(batch_size):
        for i in range(bd, s1 - bd):
            for j in range(bd, s2 - bd):
                blocks = imgSeq_stack[:, batch, i - bd: i + bd + 1, j - bd: j + bd + 1]  # (img_num, wSize, wSize)
                rBlock = torch.zeros((wSize, wSize)).cuda()
                for k in range(s3):
                    # eq (4); mu/ed/wMap are same-padded, so the window centred at (i, j) lives at index (i, j), not (i - bd, j - bd)
                    rBlock = rBlock + wMap[k, batch, i, j] * (blocks[k] - mu[k, batch, i, j]) / ed[k, batch, i, j]

                if torch.norm(rBlock) > 0:
                    rBlock = rBlock / torch.norm(rBlock) * maxEd[batch, i, j]

                fBlock = fI[batch, i - bd:i + bd + 1, j - bd:j + bd + 1]
                rVec = rBlock.view(-1, 1)
                fVec = fBlock.view(-1, 1)
                windowVec = window.view(-1, 1)
                mu1 = torch.sum(windowVec * rVec)
                mu2 = torch.sum(windowVec * fVec)
                sigma1Sq = torch.sum(windowVec * (rVec - mu1) ** 2)
                sigma2Sq = torch.sum(windowVec * (fVec - mu2) ** 2)
                sigma12 = torch.sum(windowVec * (rVec - mu1) * (fVec - mu2))
                qMap[batch, i - bd, j - bd] = (2 * sigma12 + C) / (sigma1Sq + sigma2Sq + C)

    Q = torch.mean(qMap)
    return Q, qMap
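
In case it helps, below is a rough sketch of how the per-pixel loops might be vectorized with F.unfold, which extracts every wSize x wSize patch at once ('valid' style, as in the MATLAB code). This is only my own sketch, not the author's code: the function name mef_ssim_patches is mine, it assumes a grayscale sequence shaped (N, H, W) with values in [0, 255], a fused image of shape (H, W) on the same device/dtype, no batch dimension, and it trades memory for speed (it materialises all N x 121 x P patch values, so large images may need chunking over rows).

import math
import torch
import torch.nn.functional as F


def mef_ssim_patches(imgSeq, fI, K=0.03, window_size=11, sigma=1.5):
    """Vectorised MEF-SSIM sketch (my own reading of the MATLAB code, not the author's implementation).

    imgSeq: (N, H, W) float tensor, the exposure sequence, values in [0, 255].
    fI:     (H, W) float tensor, the fused image, same dtype/device as imgSeq.
    Returns (Q, qMap); qMap has size (H - ws + 1, W - ws + 1), like MATLAB's 'valid' filtering.
    """
    eps = 2.2204e-16
    N, H, W = imgSeq.shape
    ws = window_size
    L = ws * ws

    # Gaussian window (as in fspecial('gaussian', ws, sigma)), flattened for per-patch weighted sums
    g = torch.arange(ws, dtype=imgSeq.dtype, device=imgSeq.device) - ws // 2
    g = torch.exp(-g ** 2 / (2 * sigma ** 2))
    win = g[:, None] * g[None, :]
    win = (win / win.sum()).reshape(1, L, 1)

    # every overlapping ws x ws patch, 'valid' style: (N, L, P) with P = (H-ws+1)*(W-ws+1)
    patches = F.unfold(imgSeq.unsqueeze(1), ws)
    fpatch = F.unfold(fI.unsqueeze(0).unsqueeze(0), ws)          # (1, L, P)

    # mean-removed patches and their strengths c_k = ||x_k - mu_k|| (MATLAB's ed)
    xhat = patches - patches.mean(dim=1, keepdim=True)
    cn = xhat.norm(dim=1, keepdim=True)                          # raw strengths, (N, 1, P)
    c = cn + 1e-3                                                # stabilised, as in the MATLAB code

    # eq (6)-(7): consistency map R and exponent p
    numerator = xhat.sum(dim=0, keepdim=True).norm(dim=1, keepdim=True)
    R = (numerator + eps) / (cn.sum(dim=0, keepdim=True) + eps)
    R = R.clamp(eps, 1 - eps)
    p = torch.tan(math.pi / 2 * R).clamp(max=10)                 # (1, 1, P)

    # eq (5): per-exposure weights; eqs (3)-(4): desired patch r with the largest local strength
    w = (c / ws) ** p + eps
    w = w / w.sum(dim=0, keepdim=True)
    r = (w * xhat / c).sum(dim=0, keepdim=True)                  # (1, L, P)
    rnorm = r.norm(dim=1, keepdim=True)
    r = torch.where(rnorm > 0, r / (rnorm + eps) * c.max(dim=0, keepdim=True).values, r)

    # structure/contrast term of SSIM between r and the fused patch, Gaussian-weighted
    C = (K * 255) ** 2
    mu_r = (win * r).sum(dim=1, keepdim=True)
    mu_f = (win * fpatch).sum(dim=1, keepdim=True)
    s_rr = (win * (r - mu_r) ** 2).sum(dim=1)
    s_ff = (win * (fpatch - mu_f) ** 2).sum(dim=1)
    s_rf = (win * (r - mu_r) * (fpatch - mu_f)).sum(dim=1)
    qMap = ((2 * s_rf + C) / (s_rr + s_ff + C)).reshape(H - ws + 1, W - ws + 1)
    return qMap.mean(), qMap

Because it works on explicit patches, this keeps the MATLAB convention of one scalar mean per patch (rather than the per-pixel local means my conv-based version uses), which I suspect is where most of the GitHub ports drift from the reference implementation.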

We have a similar implementation in kornia:
https://kornia.readthedocs.io/en/latest/losses.html#kornia.losses.SSIM

Thanks for the reply.
But SSIM is actually different from MEFSSIM :sweat_smile: