# Maximum mean discrepancy (MMD) and radial basis function (rbf)

What is a concise and correct way to implement the RBF kernel and MMD for two vectors? Can the RBF function be computed directly with `torch.norm`?

My implementation to compute the MMD between two sets of samples:

Here x and y are batches of images with shape `[B,1,W,H]`

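```python
import torch

def mmd_rbf(x, y, alpha):
    # x, y: image batches of shape [B, 1, W, H] with the same batch size B
    B = x.size(0)
    x = x.view(x.size(0), x.size(2) * x.size(3))
    y = y.view(y.size(0), y.size(2) * y.size(3))

    # Gram matrices of inner products
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

    # rx[i, j] = ||x_j||^2 and ry[i, j] = ||y_j||^2
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)

    # RBF kernels exp(-alpha * ||a - b||^2), using
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>
    K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))
    L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))
    P = torch.exp(-alpha * (rx.t() + ry - 2 * zz))

    beta = 1. / (B * (B - 1))
    gamma = 2. / (B * B)

    # The 1/(B*(B-1)) weight averages over i != j pairs, so drop the diagonals
    within = torch.sum(K) - torch.trace(K) + torch.sum(L) - torch.trace(L)
    return beta * within - gamma * torch.sum(P)
```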

Credit goes to @fmassa for this previous answer.

You can swap in a different kernel by changing the definitions of `K`, `L`, and `P`.
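As a rough sketch of what swapping the kernel could look like, the snippet below builds the kernel matrices from pairwise distances with `torch.cdist` rather than from the Gram-matrix expansion. The helper names `rbf_kernel`, `laplacian_kernel`, and `mmd` are made up for illustration, and the estimator is the simple biased one (plain means over all pairs), which is an assumption and not necessarily the weighting used in the code above.

```python
import torch

def rbf_kernel(a, b, alpha=1.0):
    # exp(-alpha * ||a_i - b_j||^2) for every pair of rows
    return torch.exp(-alpha * torch.cdist(a, b, p=2) ** 2)

def laplacian_kernel(a, b, alpha=1.0):
    # exp(-alpha * ||a_i - b_j||_1), one possible alternative kernel
    return torch.exp(-alpha * torch.cdist(a, b, p=1))

def mmd(x, y, kernel=rbf_kernel):
    # Biased estimate of MMD^2: plain means over all pairs (assumed here)
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

torch.manual_seed(0)
x = torch.randn(8, 16)          # flattened samples, one per row
y = torch.randn(8, 16) + 1.0    # same shape, shifted distribution

same = mmd(x, x)                              # identical samples
shifted_rbf = mmd(x, y)                       # RBF kernel
shifted_lap = mmd(x, y, kernel=laplacian_kernel)
```

Since `torch.cdist` already returns the pairwise norms, this is probably the closest thing to computing the RBF kernel "directly" from a norm. The value of `alpha=1.0` is only a placeholder.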

Thank you so much for the fast reply! In this implementation, I guess x and y do not need to have the same B (batch size)? Or does B need to be the same?

With this implementation your batches have to be the same size: `rx.t()` and `ry` are square matrices whose sizes follow the batch sizes of `x` and `y`, so they cannot be added if the batch sizes differ (and then `zz` also won't be square).

The best option is just to randomly sample `N` elements from each of your batches so they are the same size.
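A minimal sketch of that subsampling step, assuming both batches are ordinary tensors with the batch dimension first. The helper name `subsample_to_common_size` is hypothetical, and `N` is taken to be the smaller of the two batch sizes:

```python
import torch

def subsample_to_common_size(x, y):
    # Hypothetical helper: keep N = min(len(x), len(y)) random rows of each batch
    n = min(x.size(0), y.size(0))
    idx_x = torch.randperm(x.size(0))[:n]  # random indices, no repeats
    idx_y = torch.randperm(y.size(0))[:n]
    return x[idx_x], y[idx_y]

x = torch.randn(10, 1, 28, 28)  # batch of 10 images
y = torch.randn(6, 1, 28, 28)   # batch of 6 images
x_sub, y_sub = subsample_to_common_size(x, y)
print(x_sub.shape, y_sub.shape)  # both have batch size 6
```

The two subsampled batches can then be passed to the MMD code, which expects equal batch sizes.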

I see. Thank you again for the explanation.

You may want to check out https://github.com/josipd/torch-two-sample, a PyTorch-based library.

@Jordan_Campbell, what is the value of alpha here? I mean, what would be a suitable general-purpose value for alpha? One more thing: in my case the batch size is B = 1, so what will the value of beta be?