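My implementation for computing the (squared) maximum mean discrepancy (MMD) between two sets of samples with a Gaussian kernel.
Here `x` and `y` are batches of single-channel images with shape [B, 1, W, H]. Both batches must contain the same number of samples B.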
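```python
import torch

def mmd(x, y, alpha):
    # x, y: [B, 1, W, H] image batches of the same size B; alpha: RBF bandwidth
    B = x.size(0)
    x = x.reshape(B, -1)  # flatten each image to a vector of length W*H
    y = y.reshape(B, -1)
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = xx.diag().unsqueeze(0).expand_as(xx)  # rx[i, j] = ||x_j||^2
    ry = yy.diag().unsqueeze(0).expand_as(yy)  # ry[i, j] = ||y_j||^2
    # Gaussian kernel of the pairwise squared distances ||a - b||^2
    K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))
    L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))
    P = torch.exp(-alpha * (rx.t() + ry - 2 * zz))
    beta = 1. / (B * (B - 1))
    gamma = 2. / (B * B)
    # Drop the i == j terms of K and L so that beta gives an unbiased estimate
    return beta * (torch.sum(K) - torch.trace(K) + torch.sum(L) - torch.trace(L)) - gamma * torch.sum(P)
```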
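For reference, `beta` and `gamma` are the normalizations of the standard unbiased estimator of the squared MMD between two samples of size B. In that estimator the within-sample sums skip the i = j terms. The kernel here is the Gaussian k(a, b) = exp(-alpha ||a - b||^2):

$$
\widehat{\mathrm{MMD}}^2_u = \underbrace{\frac{1}{B(B-1)}}_{\beta} \sum_{i \neq j} \big[ k(x_i, x_j) + k(y_i, y_j) \big] - \underbrace{\frac{2}{B^2}}_{\gamma} \sum_{i,j} k(x_i, y_j),
\qquad
\lVert x_i - y_j \rVert^2 = x_i^\top x_i + y_j^\top y_j - 2\, x_i^\top y_j
$$

The last identity is what `rx.t() + ry - 2*zz` evaluates for every pair (i, j) at once. The same trick applied to `xx` and `yy` gives the within-sample distances.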
Credit goes to @fmassa for this previous answer.
You can swap in a different kernel by changing the definitions of K, L and P. As written, they use the Gaussian (RBF) kernel exp(-alpha * ||a - b||^2).
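Below is a rough sketch of that idea, not the original code. It passes the kernel in as a function of the pairwise squared-distance matrix. The names `pairwise_sq_dists`, `gaussian_kernel`, `multiscale_gaussian_kernel` and `mmd2_unbiased` are hypothetical, and so are the bandwidths in the demo. The within-sample sums skip the diagonal, so this computes the unbiased MMD² estimate. It also allows the two batches to have different sizes.

```python
import torch


def pairwise_sq_dists(a, b):
    # ||a_i - b_j||^2 = a_i.a_i + b_j.b_j - 2 a_i.b_j, for every row pair (i, j)
    d2 = (a * a).sum(1).unsqueeze(1) + (b * b).sum(1).unsqueeze(0) - 2 * a @ b.t()
    return d2.clamp(min=0)  # remove tiny negatives caused by rounding


def gaussian_kernel(d2, alpha):
    # Gaussian (RBF) kernel, the same form as K, L, P above
    return torch.exp(-alpha * d2)


def multiscale_gaussian_kernel(d2, alphas):
    # Hypothetical alternative: a sum of Gaussian kernels with several bandwidths
    return sum(torch.exp(-a * d2) for a in alphas)


def mmd2_unbiased(x, y, kernel):
    # Flatten each sample, e.g. [B, 1, W, H] -> [B, W*H]
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)
    m, n = x.size(0), y.size(0)
    K = kernel(pairwise_sq_dists(x, x))
    L = kernel(pairwise_sq_dists(y, y))
    P = kernel(pairwise_sq_dists(x, y))
    # Within-sample averages exclude the i == j terms (unbiased estimator)
    k_xx = (K.sum() - K.diag().sum()) / (m * (m - 1))
    k_yy = (L.sum() - L.diag().sum()) / (n * (n - 1))
    k_xy = P.sum() / (m * n)
    return k_xx + k_yy - 2 * k_xy


if __name__ == "__main__":
    torch.manual_seed(0)
    D = 8 * 8
    x = torch.randn(64, 1, 8, 8)
    y_same = torch.randn(64, 1, 8, 8)         # same distribution as x
    y_shift = torch.randn(64, 1, 8, 8) + 1.0  # mean shifted by 1
    alpha = 1.0 / (2 * D)  # assumed bandwidth heuristic for standardized inputs
    rbf = lambda d2: gaussian_kernel(d2, alpha)
    multi = lambda d2: multiscale_gaussian_kernel(d2, (alpha / 2, alpha, 2 * alpha))
    for name, k in [("rbf", rbf), ("multiscale", multi)]:
        print(name, mmd2_unbiased(x, y_same, k).item(), mmd2_unbiased(x, y_shift, k).item())
```

The bandwidth matters a lot. If alpha * ||a - b||^2 is large for typical pairs, every off-diagonal kernel value underflows toward zero, and the estimate collapses to about 0 whatever the data. Scaling alpha by 1/(2D) for D-dimensional standardized inputs, as in the demo, is only one assumed heuristic.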