Hi,
I have developed a domain adaptation model trained with a customized loss: maximum mean discrepancy (MMD) combined with the categorical cross-entropy loss.
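For context, the quantity the RBF branch below is meant to estimate is, as I understand it, the biased squared MMD between the source features X and the target features Y under a kernel k:

MMD^2(X, Y) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)],

with each expectation taken over pairs drawn from the corresponding samples. The simplified version of my code is as below: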
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, input=100, hidden=50, n_class=5):
        super(Net, self).__init__()
        self.L1 = nn.Linear(input, hidden)    # shared feature layer
        self.L2 = nn.Linear(hidden, n_class)  # classifier head

    def forward(self, Source, Target):
        x_src_mmd = self.L1(Source)  # source features used for MMD
        x_tar_mmd = self.L1(Target)  # target features used for MMD
        x_src = self.L2(x_src_mmd)   # class logits for the source batch
        return x_src, x_src_mmd, x_tar_mmd
class MMD_loss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        # Multi-kernel RBF: a sum of Gaussian kernels whose bandwidths form
        # a geometric series around a mean-distance heuristic.
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)  # pairwise squared distances
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
            bandwidth = bandwidth.float()
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)
    def linear_mmd2(self, f_of_X, f_of_Y):
        # Linear-kernel MMD: squared distance between the feature means.
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta)
        return loss
    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.gaussian_kernel(
                source, target, kernel_mul=self.kernel_mul,
                kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            with torch.no_grad():
                XX = torch.mean(kernels[:batch_size, :batch_size])  # source-source block
                YY = torch.mean(kernels[batch_size:, batch_size:])  # target-target block
                XY = torch.mean(kernels[:batch_size, batch_size:])  # source-target block
                YX = torch.mean(kernels[batch_size:, :batch_size])  # target-source block
                loss = torch.mean(XX + YY - XY - YX)
            torch.cuda.empty_cache()
            return loss
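Before the training script, here is the standalone sanity check I have in mind for the loss module itself (a minimal sketch; the random tensors are just stand-ins for the 50-dim hidden features):

# Sanity check: does the RBF-MMD loss come back attached to the autograd graph?
src = torch.randn(32, 50, requires_grad=True)  # fake source features
tar = torch.randn(32, 50, requires_grad=True)  # fake target features
out = MMD_loss()(src, tar)
# For a loss that can backpropagate, requires_grad should be True
# and grad_fn should not be None.
print(out.item(), out.requires_grad, out.grad_fn)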
MMD = MMD_loss()
Source_data = ---    # source features (placeholder)
Source_labels = ---  # source class labels (placeholder)
Target_data = ---    # target features (placeholder)
Target_labels = ---  # target labels, unused in the loop below (placeholder)
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
factor = 0.5  # weight of the MMD term in the total loss
for i in range(num_epochs):
    for j in range(num_batches):
        optimizer.zero_grad()
        output_src, output_src_mmd, output_tar_mmd = model(Source_data, Target_data)
        loss_c = criterion(output_src, Source_labels)   # classification loss on source logits
        mmd_loss = MMD(output_src_mmd, output_tar_mmd)  # domain alignment loss
        loss = loss_c + factor * mmd_loss
        loss.backward()
        optimizer.step()
        print('CE loss: {:.3f}  MMD loss: {:.3f}'.format(loss_c.item(), mmd_loss.item()))
My problem is that during training the mmd_loss fluctuates constantly and never converges, while the cross-entropy loss converges fine. I'm not sure the mmd_loss actually backpropagates through the network. Where do you think the problem comes from? Is there any operation in the MMD_loss forward that breaks the computation graph?
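To make this easier to check, this is the diagnostic I plan to run inside the loop: backpropagate only the MMD term and inspect the gradient on the shared layer (a sketch using the names defined above, with real batches in place of the placeholders):

# Backprop the MMD term alone and inspect the shared layer's gradient.
optimizer.zero_grad()
_, output_src_mmd, output_tar_mmd = model(Source_data, Target_data)
(factor * MMD(output_src_mmd, output_tar_mmd)).backward()
# A None/zero gradient here (or backward() raising an error) would mean
# the MMD term is detached from the computation graph.
print(model.L1.weight.grad)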
I would appreciate any suggestions.