My code raises this error:

def compute_kernel(self, x, y):
    """Return the matrix of pairwise kernel values between the rows of x and y.

    Entry ``(i, j)`` is ``exp(-mean_k((x[i, k] - y[j, k]) ** 2) / dim)``,
    which equals ``exp(-||x[i] - y[j]||**2 / dim**2)``. The squared
    differences are averaged over the feature axis (one division by ``dim``)
    and then divided by ``dim`` a second time.
    NOTE(review): confirm this ``dim**2`` scaling is the intended kernel
    bandwidth; it is kept unchanged here for backward compatibility.

    Args:
        x: Tensor of shape ``(x_size, ...)``. Every dimension after the first
            is flattened into a single feature axis of size ``dim``.
        y: Tensor of shape ``(y_size, ...)``, flattened the same way; its
            flattened feature size is expected to equal ``dim``.

    Returns:
        Tensor of shape ``(x_size, y_size)``.
    """
    # Flatten trailing dimensions so inputs such as (batch, C, H, W) work.
    # Previously the 3-size expand() below raised a RuntimeError for any
    # input that was not 2-D. For 2-D inputs this is a no-op.
    x = x.flatten(start_dim=1)
    y = y.flatten(start_dim=1)

    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)

    # Tile both inputs to (x_size, y_size, dim) so that every row of x is
    # paired with every row of y. expand() returns views and does not copy.
    tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
    tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)

    # mean over features divides by dim; the extra / dim is the second
    # normalization described in the docstring.
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
    return torch.exp(-kernel_input)  # (x_size, y_size)

def compute_mmd(self, x, y):
    """Estimate the squared Maximum Mean Discrepancy between samples x and y.

    Computes ``mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)`` using
    ``self.compute_kernel``. Every pair is averaged, including the i == j
    terms of k(x, x) and k(y, y), so this is the biased estimator.

    Assumes ``self.compute_kernel(a, b)`` returns a tensor of pairwise kernel
    values between the rows of ``a`` and ``b`` (TODO confirm against the
    kernel implementation in use).

    Returns:
        A scalar tensor.
    """
    kernel = self.compute_kernel
    within_x = kernel(x, x).mean()
    within_y = kernel(y, y).mean()
    between = kernel(x, y).mean()
    return within_x + within_y - 2 * between