Maximum mean discrepancy (MMD) and radial basis function (RBF)

What is a concise and correct way to implement the RBF kernel and MMD, given two vectors? Can the RBF function be calculated directly using torch.norm?
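On the torch.norm part of the question: yes, for two single vectors the RBF kernel k(x, y) = exp(-||x - y||² / (2σ²)) can be computed with torch.norm directly. Below is a minimal sketch; the function names and the bandwidth parameter `sigma` are illustrative assumptions. For batches it uses torch.cdist, which returns all pairwise Euclidean distances between rows at once. In the `alpha` parametrization used later in this thread, alpha = 1 / (2σ²).

```python
import torch


def rbf_pair(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """RBF kernel between two single vectors: exp(-||x - y||^2 / (2 sigma^2))."""
    return torch.exp(-torch.norm(x - y) ** 2 / (2 * sigma ** 2))


def rbf_matrix(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Pairwise RBF kernel matrix between rows of x [m, d] and rows of y [n, d]."""
    # torch.cdist gives the [m, n] matrix of Euclidean distances ||x_i - y_j||
    d2 = torch.cdist(x, y) ** 2
    return torch.exp(-d2 / (2 * sigma ** 2))


x = torch.tensor([0.0, 0.0])
y = torch.tensor([1.0, 1.0])
print(rbf_pair(x, y))  # exp(-2 / 2) = exp(-1), about 0.3679
```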

My implementation to compute the MMD between two sets of samples:

Here x and y are batches of single-channel images with shape [B, 1, W, H].

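```python
import torch


def mmd_rbf(x, y, alpha):
    """Unbiased MMD^2 with kernel exp(-alpha * ||a - b||^2).

    x, y: image batches of shape [B, 1, W, H] (same B, with B >= 2).
    """
    B = x.size(0)
    # flatten each image into a row vector: [B, W*H]
    x = x.view(B, -1)
    y = y.view(B, -1)

    # Gram matrices of inner products
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

    # rx[i, j] = ||x_j||^2, so rx.t() + rx - 2*xx is ||x_i - x_j||^2
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)

    K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))
    L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))
    P = torch.exp(-alpha * (rx.t() + ry - 2 * zz))

    beta = 1. / (B * (B - 1))
    gamma = 2. / (B * B)

    # drop the i == j terms (each exp(0) = 1) so beta averages over distinct pairs only
    k_xx = torch.sum(K) - torch.trace(K)
    k_yy = torch.sum(L) - torch.trace(L)
    return beta * (k_xx + k_yy) - gamma * torch.sum(P)
```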

Credit goes to @fmassa for this previous answer.

You can easily change the kernel by changing the definitions of K, L, and P.
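As a sketch of what swapping the kernel can look like, the version below (an illustrative rewrite, not the code from this answer) takes the kernel as a function of two sample matrices and builds K, L and P from torch.cdist. The helpers `rbf`, `laplacian` and `multi_rbf` are hypothetical names, and any bandwidth values passed to them are placeholders. Because the cross term is averaged over an m × n matrix, this form also accepts x and y with different numbers of rows.

```python
import torch
from typing import Callable, Sequence

# A kernel maps samples a [m, d] and b [n, d] to the [m, n] matrix of k(a_i, b_j).
Kernel = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def rbf(alpha: float) -> Kernel:
    # exp(-alpha * ||a_i - b_j||^2), the same kernel as K, L, P above
    return lambda a, b: torch.exp(-alpha * torch.cdist(a, b) ** 2)


def laplacian(alpha: float) -> Kernel:
    # exp(-alpha * ||a_i - b_j||): distance instead of squared distance
    return lambda a, b: torch.exp(-alpha * torch.cdist(a, b))


def multi_rbf(alphas: Sequence[float]) -> Kernel:
    # sum of RBF kernels over several bandwidths
    return lambda a, b: sum(rbf(al)(a, b) for al in alphas)


def mmd2(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor:
    """Unbiased MMD^2 between x [m, d] and y [n, d]; needs m, n >= 2."""
    m, n = x.size(0), y.size(0)
    K, L, P = kernel(x, x), kernel(y, y), kernel(x, y)
    # within-sample averages over distinct pairs i != j
    k_xx = (K.sum() - K.diagonal().sum()) / (m * (m - 1))
    k_yy = (L.sum() - L.diagonal().sum()) / (n * (n - 1))
    # cross term averaged over all m * n pairs
    return k_xx + k_yy - 2 * P.mean()
```

For image batches, flatten first, e.g. `mmd2(x.view(x.size(0), -1), y.view(y.size(0), -1), rbf(0.5))`, where 0.5 is only a placeholder bandwidth.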

Thank you so much for the fast reply! In this implementation, do x and y need to have the same batch size B, or can they differ?

With this implementation your batches have to be the same size: rx.t() is B_x × B_x and ry is B_y × B_y, so rx.t() + ry can't be computed when the batch sizes differ (and zz would be B_x × B_y, which isn't square).

The simplest option is to randomly sample N elements from each of your batches so they are the same size.
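Random subsampling to a common size could look like the sketch below; the helper name `subsample_to_common_size` is hypothetical. It draws indices without replacement via torch.randperm, so each selected sample appears at most once, and N defaults to the smaller of the two batch sizes.

```python
import torch
from typing import Optional, Tuple


def subsample_to_common_size(
    x: torch.Tensor, y: torch.Tensor, n: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly pick n rows from x and n rows from y, without replacement."""
    if n is None:
        n = min(x.size(0), y.size(0))
    if n > min(x.size(0), y.size(0)):
        raise ValueError("n cannot exceed the smaller batch size")
    ix = torch.randperm(x.size(0), device=x.device)[:n]
    iy = torch.randperm(y.size(0), device=y.device)[:n]
    return x[ix], y[iy]


# Example: two image batches of different sizes
x = torch.randn(32, 1, 28, 28)
y = torch.randn(20, 1, 28, 28)
xs, ys = subsample_to_common_size(x, y)
print(xs.shape, ys.shape)  # both torch.Size([20, 1, 28, 28])
```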

I see. Thank you again for the explanation.

You may want to check out https://github.com/josipd/torch-two-sample, a PyTorch-based library.

@Jordan_Campbell, what is the value of alpha here? That is, what would be a suitable general value for alpha? One more thing: in my case the batch size is B = 1, so what will the value of beta be?
Thanks in advance.

I’m not sure comparing distributions makes sense for a single sample. Also, with B = 1 the normalization beta = 1/(B*(B-1)) becomes 1/0, so the estimator isn't defined.
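On the alpha part of the question: there is no single general value, since a good alpha depends on the scale of the data. One common heuristic (an assumption here, not something proposed in this thread) is the median heuristic: take σ as the median distance between distinct samples of the pooled data and set alpha = 1/(2σ²), matching the kernel exp(-alpha * ||a - b||²) used in the answer above. Other variants of the heuristic scale it differently. The function name `median_heuristic_alpha` is hypothetical.

```python
import torch


def median_heuristic_alpha(x: torch.Tensor, y: torch.Tensor) -> float:
    """alpha = 1 / (2 * sigma^2), sigma = median distance between distinct pooled samples."""
    # flatten each sample and pool both sets: [m + n, d]
    z = torch.cat([x.reshape(x.size(0), -1), y.reshape(y.size(0), -1)], dim=0)
    d = torch.cdist(z, z)
    # keep each pair i < j once (strict upper triangle)
    iu = torch.triu_indices(z.size(0), z.size(0), offset=1, device=z.device)
    sigma = d[iu[0], iu[1]].median()
    # guard against sigma == 0 (e.g. all samples identical)
    sigma = sigma.clamp_min(1e-12)
    return (1.0 / (2 * sigma ** 2)).item()
```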

I found this code when reading the thread:
