Mutual information between input training image and difference between two logits

Hello,

I’m trying to replicate the experiment by Rafael Müller et al. in “When does label smoothing help?” (NeurIPS 2019).

The method for evaluating the effect of label smoothing on network distillation in classification tasks is mainly described in Section 4 and Figure 6:

"We measure the mutual information between X and Y, where X is a discrete variable representing the index of the training example and Y is continuous, representing the difference between two logits (out of K classes)."

The exact formula used for the mutual information approximation is given on page 7 of the paper.

I’m trying to replicate this experiment with a ResNet-18 on the CIFAR-10 dataset.

After each epoch, I run the piece of code below to compute the mutual information as described in the paper. However, the final value I get is far outside its feasible range of 0 to log(N) (here log(600) ≈ 6.4).

Setup: batchsize = 300, N = 600 (number of training examples used in the mutual information calculation), and L = 100 (Monte Carlo samples).
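For context, the subset loaders used in the snippet are built roughly like this (a simplified sketch, not my exact code; the important points are that `shuffle=False` keeps example i at a fixed position so it lines up with `mu_x`, and that the L loaders re-sample the random augmentation on the same N images):

    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader, Subset

    N, L, batchsize = 600, 100, 300
    classes_mi = (0, 1)          # indices of the two classes whose logit difference defines y

    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    plain = transforms.ToTensor()

    trainset_aug = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=augment)
    trainset_plain = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=plain)

    indices = list(range(N))     # a fixed subset of N training examples, in a fixed order

    # L loaders over the same N images; every pass through a loader re-samples the augmentation.
    trainloader_sub_transforms = [
        DataLoader(Subset(trainset_aug, indices), batch_size=batchsize, shuffle=False)
        for _ in range(L)
    ]
    # One loader without augmentation, used for the variance and mutual information passes.
    trainloader_sub = DataLoader(Subset(trainset_plain, indices), batch_size=batchsize, shuffle=False)

The mutual information computation itself: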

        with torch.no_grad():
            # ---- Mean estimation ----
            # mu_x is initialized to np.zeros(self.N) before this block;
            # classes_mi holds the indices of the two classes whose logit difference defines y.
            # Each of the L passes sees a differently augmented copy of the same N examples,
            # so mu_x[i] becomes the Monte Carlo mean of y_i over L samples.
            for i in range(self.L):
                for batch_idx, (inputs, targets) in enumerate(trainloader_sub_transforms[i]):
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    outputs = self.net(inputs).cpu().numpy()
                    # y = |logit_a - logit_b| for the two selected classes
                    outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
                    start = batch_idx * 300  # 300 = batchsize
                    mu_x[start:start + len(targets)] += outputs
            mu_x /= self.L
            print('--> Finished mean calculation', mu_x[:10])

            # ---- Shared variance estimation ----
            for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.net(inputs).cpu().numpy()
                outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
                start = batch_idx * 300
                var = np.sum((outputs - mu_x[start:start + len(targets)]) ** 2)
            var /= self.N
            print('--> Finished variance calculation', var)

            # ---- Mutual information estimate ----
            mutual_info_value = np.zeros(self.N)
            term2 = 0
            for batch_idx, (inputs, targets) in enumerate(trainloader_sub):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.net(inputs).cpu().numpy()
                outputs = np.absolute(outputs[:, classes_mi[0]] - outputs[:, classes_mi[1]])
                start = batch_idx * 300
                # main_term[i] = -(y_i - mu_x[i])^2 / (2 * var)
                main_term = -(outputs - mu_x[start:start + len(targets)]) ** 2 / (2 * var)
                term2 += np.sum(np.exp(main_term))
                mutual_info_value[start:start + len(targets)] = main_term
            term2 = np.log(term2)
            mutual_info_value -= term2
            print(mutual_info_value.shape, mutual_info_value[:10])
            mutual_info_value = np.sum(mutual_info_value)
        sum += mutual_info_value  # accumulated over the outer loop of trials t
        print('trial:', t, ' --> MI =', mutual_info_value)
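To make sure I’m not getting lost in the batching, I also wrote the estimate out directly from the formula above as a small vectorized function. Again, this reflects my reading of the formula (so it may itself be wrong), and `mutual_info_estimate`, `y`, and `mu` are just names I use here:

    import numpy as np
    from scipy.special import logsumexp

    def mutual_info_estimate(y, mu):
        """I(X;Y) estimate from per-example values y[i] and Monte Carlo means mu[i].

        y, mu: 1-D arrays of length N. Assumes a Gaussian p(y | x=i) with mean mu[i]
        and a single shared variance, which is my reading of the paper's formula.
        """
        N = len(y)
        var = np.mean((y - mu) ** 2)                       # shared variance estimate
        sq = (y[:, None] - mu[None, :]) ** 2 / (2 * var)   # sq[i, j] = (y_i - mu_j)^2 / (2 var)
        log_p_y_given_x = -np.diag(sq)                     # log N(y_i; mu_i, var), up to a constant
        log_p_y = -np.log(N) + logsumexp(-sq, axis=1)      # log of the mixture over all j, same constant
        return np.mean(log_p_y_given_x - log_p_y)          # should land in roughly [0, log N]

    # y holds one forward pass per example (no augmentation), mu the Monte Carlo means:
    # mi = mutual_info_estimate(y, mu_x)

Is this the right reading of the formula, and if so, where does my batched loop above diverge from it?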