Customized loss function not converging

Hi,
I have developed a domain adaptation model with customized loss function maximum mean discrepancy (mmd) along with the categorical cross-entropy loss. The simplified version of my code is as below:

class Net(nn.Module):
    def __init__(self, input = 100, hidden =50, n_class=5):
        super(Net, self).__init__()
        self.L1 = nn.Linear(input, hidden) 
        self.L2 = nn.Linear(hidden,n_class)

    
    def forward(self, Source, Target):
        x_src_mmd = self.L1(Source)
        x_tar_mmd = self.L1(Target)
        x_src = self.L2(x_src_mmd)                            
        return x_src, x_src_mmd, x_tar_mmd


class MMD_loss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5.0, fix_sigma=None):
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0-total1)**2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
            bandwidth = bandwidth.float()
        bandwidth /= kernel_mul ** (kernel_num // 2.0)
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        loss = 0.0
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta.T)
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            with torch.no_grad():
                XX = torch.mean(kernels[:batch_size, :batch_size])
                YY = torch.mean(kernels[batch_size:, batch_size:])
                XY = torch.mean(kernels[:batch_size, batch_size:])
                YX = torch.mean(kernels[batch_size:, :batch_size])
                loss = torch.mean(XX + YY - XY - YX)
            torch.cuda.empty_cache()
            return loss

MMD = MMD_loss()


Source_data = ---
Source_labels = ---
Target_data = ---
Target_labels = ---

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
factor = 0.5

for i in num_epochs
    for j in num_batches:
        optimizer.zero_grad()
        output_src, output_src_mmd, output_tar_mmd = model(Source_data,Target_data)
        loss_c = criterion(Source_data,Source_labels)
        mmd_loss = MMD(output_src_mmd, output_tar_mmd)
        loss = loss_c + factor*mmd_loss
        loss.backward()
        optimizer.step()
    print('Loss: {:.3f}'.format(loss.item()))

My problem is that during the model training mmd_loss is constantly fluctuating and never converges. But converging in terms of CrossEntropyloss is ok. I’m not sure that mmd_loss backpropagates through the network. Where do you think the problem comes from? Is there any operation in the MMD_loss function that breaks the computation graph?
I would appreciate it if you have any suggestions.

Your loss calculation is wrapped into a no_grad block:

with torch.no_grad():
    XX = torch.mean(kernels[:batch_size, :batch_size])
    YY = torch.mean(kernels[batch_size:, batch_size:])
    XY = torch.mean(kernels[:batch_size, batch_size:])
    YX = torch.mean(kernels[batch_size:, :batch_size])
    loss = torch.mean(XX + YY - XY - YX)

will will make sure that Autograd isn’t tracking these operations.
Could you disable the nn.CrossEntropyLoss, only use the mmd_loss, call backward on it, and check the model parameters for valid gradients?

Many thanks for your reply. Yes, it worked and loss function decreased. But I still have a problem. It seems that on my main dataset, mmd_loss doesn’t work. My goal is to train the model on the source dataset and test this model on the target one. Source and target datasets are from two different domains. As they have different distributions. I’ve used mmd_loss in network training to minimize the discrepancy of source and target datasets. While when I train the network, from the first epoch, the value of mmd_loss is minimized! Although the mmd_loss value is minimized, the results in the classification of the target dataset are not satisfactory.
I don’t know where the problem is! The network structure is as follows:

class ConvNet(nn.Module):
    def __init__(self, num_conv1 =4, size_conv1 =64, num_conv2 =8, size_conv2 =32, len_flatten =num_conv2*out_conv2, num_fc1 =500, num_fc2 = 50, n_class =8):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv1d(1, num_conv1, size_conv1)
        self.conv2 = nn.Conv1d(num_conv1, num_conv2, size_conv2)
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(len_flatten, num_fc1)
        self.fc2 = nn.Linear(num_fc1, num_fc2)
        self.fc3 = nn.Linear(num_fc2, n_class)
        self.lrelu = nn.LeakyReLU(0.01)
        self.drop = nn.Dropout(p = 0.2)
       
    def forward (self, Source, Target):
        x_src = self.drop(self.pool(self.lrelu(self.conv1(Source))))
        x_tar = self.drop(self.pool(self.lrelu(self.conv1(Target))))
        x_src = self.drop(self.pool(self.lrelu(self.conv2(x_src))))
        x_tar = self.drop(self.pool(self.lrelu(self.conv2(x_tar))))
        source_mmd = x_src.view(-1,x_src.shape[1]*x_src.shape[2])
        target_mmd = x_tar.view(-1,x_tar.shape[1]*x_tar.shape[2])           
        x_src = self.lrelu(self.fc1(source_mmd)) 
        x_src = self.lrelu(self.fc2(x_src))
        src_output = self.fc3(x_src)
        return src_output, source_mmd,target_mmd

I cannot see any obvious errors in the code, but I’m unfortunately not familiar enough with the mmd_loss approach to help with the convergence. :confused:

1 Like

Thank you very much for your help. :blush: I apologize for any inconvenience.

Hi Ali_D, have you found a solution to your problem?