Maximum mean discrepancy (MMD) and radial basis function (rbf)

(Lliu25) #1

What is a concise and correct way to implement the RBF kernel and MMD for two vectors? Can the RBF function be computed directly with torch.norm?

(Jordan Campbell) #2

My implementation to compute the MMD between two sets of samples:

Here x and y are batches of images with shape [B,1,W,H]

	x = x.view(x.size(0), x.size(2) * x.size(3))
	y = y.view(y.size(0), y.size(2) * y.size(3))

	xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())

	rx = (xx.diag().unsqueeze(0).expand_as(xx))
	ry = (yy.diag().unsqueeze(0).expand_as(yy))

	K = torch.exp(- self.alpha * (rx.t() + rx - 2*xx))
	L = torch.exp(- self.alpha * (ry.t() + ry - 2*yy))
	P = torch.exp(- self.alpha * (rx.t() + ry - 2*zz))

	B = x.size(0)
	beta = (1./(B*(B-1)))
	gamma = (2./(B*B))

	return beta * (torch.sum(K)+torch.sum(L)) - gamma * torch.sum(P)

Credit goes to @fmassa for this previous answer.
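
For anyone who wants to try this outside a class, here is a minimal sketch of the same computation as a standalone function. The name `mmd_rbf` and the default `alpha=1.0` are my own assumptions (the snippet reads `self.alpha` from the surrounding module), and I use `reshape(B, -1)` instead of `view` so it also works for non-contiguous inputs.

```python
import torch

def mmd_rbf(x, y, alpha=1.0):
    """RBF-kernel MMD between two equally sized batches (hypothetical helper)."""
    # Flatten every sample to a vector: [B, ...] -> [B, D]
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)
    B = x.size(0)

    # Gram matrices of inner products
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

    # rx[i, j] = ||x_j||^2, so rx.t() + rx - 2*xx holds squared pairwise distances
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)

    K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))  # k(x_i, x_j)
    L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))  # k(y_i, y_j)
    P = torch.exp(-alpha * (rx.t() + ry - 2 * zz))  # k(x_i, y_j)

    beta = 1.0 / (B * (B - 1))
    gamma = 2.0 / (B * B)
    return beta * (torch.sum(K) + torch.sum(L)) - gamma * torch.sum(P)
```

One thing to be aware of: `torch.sum(K)` and `torch.sum(L)` include the diagonal entries k(x_i, x_i) = 1, so with the 1/(B(B-1)) weight the result is not exactly zero when x and y are the same batch. Masking out the diagonal before summing would give the usual unbiased estimate.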

You can easily change the kernel by changing the definition for K,L,P.
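
As a rough illustration of swapping the kernel, the sketch below passes the kernel in as a function. The names `rbf_kernel`, `laplacian_kernel` and `mmd` are hypothetical, and I build the distance matrices with `torch.cdist` rather than the Gram-matrix trick, which is a substitution on my part. It assumes x and y are already flattened to [B, D] and have the same B. The last lines show that for a single pair of vectors the RBF value can indeed be written with `torch.norm`.

```python
import torch

def rbf_kernel(a, b, alpha=1.0):
    # torch.cdist returns pairwise Euclidean distances, shape [len(a), len(b)]
    return torch.exp(-alpha * torch.cdist(a, b) ** 2)

def laplacian_kernel(a, b, alpha=1.0):
    # Same idea with the L1 distance instead of the squared L2 distance
    return torch.exp(-alpha * torch.cdist(a, b, p=1))

def mmd(x, y, kernel=rbf_kernel):
    """Same weighting as the original snippet, with the kernel passed in."""
    B = x.size(0)
    K, L, P = kernel(x, x), kernel(y, y), kernel(x, y)
    return (K.sum() + L.sum()) / (B * (B - 1)) - 2.0 * P.sum() / (B * B)

# For one pair of 1-D vectors the RBF value is a single torch.norm call
u, v = torch.randn(5), torch.randn(5)
k_uv = torch.exp(-1.0 * torch.norm(u - v) ** 2)
```

Calling `mmd(x, y, kernel=laplacian_kernel)` then changes K, L and P in one place.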

Linked thread: Fastest way to find nearest neighbor for a set of points

(Lliu25) #3

Thank you so much for the fast reply! In this implementation, I guess x and y do not need to have the same B (batch size)? Or B needs to be the same?

(Jordan Campbell) #4

With this implementation your batches will have to be the same size. rx.t() is a [Bx, Bx] matrix and ry is a [By, By] matrix, so they can't be added together when the batch sizes differ (and zz, which is [Bx, By], won’t be square either).
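
A quick way to see the mismatch is to print the shapes. The batch sizes 5 and 3 and the feature size 16 below are arbitrary values picked for illustration.

```python
import torch

x = torch.randn(5, 16)  # 5 flattened samples
y = torch.randn(3, 16)  # 3 flattened samples

xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
rx = xx.diag().unsqueeze(0).expand_as(xx)
ry = yy.diag().unsqueeze(0).expand_as(yy)

print(rx.t().shape, ry.shape, zz.shape)
# torch.Size([5, 5]) torch.Size([3, 3]) torch.Size([5, 3])

try:
    rx.t() + ry - 2 * zz
    failed = False
except RuntimeError:
    # [5, 5] and [3, 3] cannot be broadcast against each other
    failed = True
```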

The best option is just to randomly sample N elements from each of your batches so they are the same size.
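
Something along these lines should do the subsampling. `subsample_to_match` is a name I made up, and taking N as the smaller of the two batch sizes is my assumption; any N up to that value works the same way.

```python
import torch

def subsample_to_match(x, y):
    """Randomly keep N = min(len(x), len(y)) samples from each batch."""
    n = min(x.size(0), y.size(0))
    # torch.randperm gives a random ordering of the indices; keep the first n
    ix = torch.randperm(x.size(0))[:n]
    iy = torch.randperm(y.size(0))[:n]
    return x[ix], y[iy]

x = torch.randn(10, 1, 8, 8)
y = torch.randn(6, 1, 8, 8)
x_s, y_s = subsample_to_match(x, y)
print(x_s.shape, y_s.shape)
# torch.Size([6, 1, 8, 8]) torch.Size([6, 1, 8, 8])
```

Sampling is without replacement, so no image is used twice within a batch.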

(Lliu25) #5

I see. Thank you again for the explanation.