Help with histogram and loss.backward()

I’m trying to code a loss function that calculates the mutual information (MI) between two images. This requires, among other things, calculating the histograms of the images. My code correctly calculates the MI between two images, but I noticed that the loss wasn’t improving and my gradients weren’t updating. I figured something was dropping off the autograd graph, and I think I’ve traced it back to my histogram calculation.
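For context, MI between two images is usually estimated from their joint intensity histogram. Below is a minimal sketch of such a (non-differentiable) calculation. It assumes intensities are min-max scaled into equal-width bins; the name histogram_mi and the default of 16 bins are illustrative choices, not the poster's code. The integer cast and torch.bincount are exactly the steps that carry no gradient.

```python
import torch

def histogram_mi(a, b, bins=16):
    """MI (in nats) between two same-sized images, from their joint intensity histogram."""
    def to_bins(x):
        # Min-max scale to [0, 1], then quantise into integer bin indices 0..bins-1
        x = x.flatten().float()
        x = (x - x.min()) / (x.max() - x.min() + 1e-12)
        return torch.clamp((x * bins).long(), max=bins - 1)

    ia, ib = to_bins(a), to_bins(b)
    # Joint histogram: one bincount over the combined index (row = bin of a, col = bin of b)
    joint = torch.bincount(ia * bins + ib, minlength=bins * bins).reshape(bins, bins).float()
    p_ab = joint / joint.sum()
    p_a = p_ab.sum(dim=1, keepdim=True)   # marginal of a, shape (bins, 1)
    p_b = p_ab.sum(dim=0, keepdim=True)   # marginal of b, shape (1, bins)
    nz = p_ab > 0                         # empty cells contribute 0 (0 * log 0 := 0)
    return (p_ab[nz] * torch.log(p_ab[nz] / (p_a * p_b)[nz])).sum()

torch.manual_seed(0)
img = torch.rand(64, 64)
mi_self = histogram_mi(img, img)                  # equals the binned entropy of img
mi_indep = histogram_mi(img, torch.rand(64, 64))  # close to 0 for independent noise
print(mi_self.item(), mi_indep.item())
```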

Here is some example code to reproduce the issue:

import torch

# Random dataset
dist1 = torch.randint(0, 16, (100,)).float().requires_grad_(True).cuda()

# Do stuff to the data to create a graph
conv = torch.nn.Conv1d(1, 1, 1).cuda()(dist1[(None,) * 2])

# Calculate histogram
if conv.min().item() < 0:
    conv = conv - conv.min()
bins = 10
conv_binned = torch.trunc(conv * bins / conv.max().item())
ones = torch.ones_like(conv_binned, requires_grad=True)
zeros = torch.zeros_like(conv_binned, requires_grad=True)
hist = torch.tensor([torch.where(conv_binned == bin, ones, zeros).sum()
                     for bin in range(bins)], requires_grad=True)

print(hist.grad_fn)
print(conv_binned.grad_fn)

Output:

None
<TruncBackward object at 0x7fa4483a0b38>

I think this is the issue: hist has no grad_fn, so it is no longer connected to the graph. Any ideas how to fix this? Thanks.

Histograms are discrete by nature, so their gradients don’t make much sense: they are zero almost everywhere.
So you cannot use this in your objective function, much as you cannot optimise classification accuracy directly.

Best regards

Thomas
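A small CPU sketch of the two effects in the original snippet, with a random tensor x standing in for the conv output (an assumption for illustration): torch.tensor(...) copies values into a fresh leaf tensor with no history, and even where the graph is kept, hard binning passes back zero gradient.

```python
import torch

torch.manual_seed(0)
x = torch.rand(100, requires_grad=True)  # stand-in for the conv output
bins = 10

# Hard binning with trunc: piecewise constant, so its derivative is zero
binned = torch.trunc(x * bins)

# 1) torch.tensor(...) builds a brand-new leaf tensor from plain numbers,
#    so the histogram carries no autograd history at all.
hist = torch.tensor([(binned == b).sum().item() for b in range(bins)], dtype=torch.float32)
print(hist.grad_fn)              # None

# 2) The comparison that assigns samples to bins is not differentiable:
#    it returns a bool tensor that does not require grad.
mask = binned == 3
print(mask.requires_grad)        # False

# 3) The part that does stay in the graph (trunc) only passes back zeros.
binned.sum().backward()
print(torch.count_nonzero(x.grad))  # tensor(0)
```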

Thanks for the reply Thomas!

It makes sense that you can’t get a gradient from a histogram. However, mutual information is really an ideal metric for my network to train on. Do you have any suggestions or thoughts on a way to implement this? Or is it really not compatible with PyTorch?

Thanks in advance.

It’s not a matter of PyTorch. Deep learning works with gradients, and that's it. The most you can do is use that metric to make decisions about something, not use it as the loss itself.

It’s really a more fundamental thing. There have been some papers maximising MI, e.g. Devon Hjelm (@erroneus) has https://arxiv.org/abs/1808.06670 and https://github.com/rdevon/DIM (last I looked, it didn’t run out of the box, though).

Best regards

Thomas


Hi @tom, DIM should definitely run out of the box now! 😄

Right now, the mutual information neural estimator (MINE, http://proceedings.mlr.press/v80/belghazi18a.html) isn’t in the public repo yet; it’s just a matter of me finding time to clean up the dev code and add it. MINE gives a lower bound on MI (asymptotically unbiased), but other options exist:
Fenchel dual of the KL divergence (see https://arxiv.org/abs/1606.00709)
Fenchel dual of the JSD (we did this in DIM to maximize MI, see our appendix A1, but see http://bayesiandeeplearning.org/2018/papers/136.pdf on using the JSD to train an estimator)
A noise-contrastive estimator, as found in CPC (https://arxiv.org/abs/1807.03748).

@thompa2 if you have any questions about these estimators or the DIM code, let me know (either here or in the DIM issues). I also plan to add these estimators to the DIM code soon, depending on time and demand.

-devon
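To make the options above concrete, here is a hedged sketch of the three objectives computed from a square matrix of critic scores, where scores[i, j] = T(x_i, y_j), the diagonal holds joint (positive) pairs and the off-diagonal holds product-of-marginal (negative) pairs. The critic network and the function names are assumptions for illustration, not code from the DIM repo. The JSD variant is used as a training objective rather than as an MI value in nats.

```python
import math

import torch
import torch.nn.functional as F

def split_pos_neg(scores):
    # Diagonal = joint (positive) pairs, off-diagonal = negative pairs
    eye = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
    return scores[eye], scores[~eye]

def dv_bound(scores):
    # Donsker-Varadhan bound (as used by MINE): E_P[T] - log E_Q[exp(T)]
    pos, neg = split_pos_neg(scores)
    return pos.mean() - (torch.logsumexp(neg, dim=0) - math.log(neg.numel()))

def jsd_estimate(scores):
    # Same expectations as get_positive_expectation / get_negative_expectation
    # in the DIM snippet later in this thread (softplus(-q) + q == softplus(q))
    pos, neg = split_pos_neg(scores)
    e_pos = math.log(2.0) - F.softplus(-pos)
    e_neg = F.softplus(neg) - math.log(2.0)
    return e_pos.mean() - e_neg.mean()

def infonce_bound(scores):
    # CPC-style InfoNCE: row i is an N-way classification of its positive (i, i)
    n = scores.size(0)
    return (scores.diagonal() - torch.logsumexp(scores, dim=1)).mean() + math.log(n)

# Usage: scores would come from a trainable critic; random values stand in here
torch.manual_seed(0)
scores = torch.randn(8, 8, requires_grad=True)
loss = -infonce_bound(scores)   # maximise the bound by minimising its negative
loss.backward()
```

Note that InfoNCE can never exceed log N, so with small batches it saturates; the DV bound has no such cap but tends to have higher variance.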


Hi @erroneus,

Finally got some time to dig through your DIM code and the paper. Very impressive and very thorough.

The code is highly modular, which makes it a little difficult to track how the DIM network gets constructed and how the loss is calculated. I made a small, hard-coded network that implements the local/global DIM loss calculation for a specific tensor shape. I wrote it from your posted DIM code and the various figures in the paper (and appendix). Can you take a look and let me know if I’m understanding things correctly?
(I’m fairly new to all this, so apologies for any obviously dumb mistakes)

Inputs A and B: 256 x 256 x 256 tensors, each representing a 3D image volume (the “Atlas” and “Moving” volumes).
The volumes are greyscale (single channel).

When I use actual image data, the network trains and seems to produce values that correlate well with the mutual information calculated directly from the histograms. That is fantastic, but only if I’m actually doing this right. I'd appreciate any insight, critiques, etc. Thanks!

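import math

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm3d, Conv3d, LayerNorm, Linear, ReLU, Sequential

A = torch.randint(0, 256, (2, 1, 256, 256, 256)).float().requires_grad_(True)
B = torch.randint(0, 256, (2, 1, 256, 256, 256)).float().requires_grad_(True)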

def get_positive_expectation(p_samples, average=True):
    #Measure = JSD (Simplified from DIM Code for clarity)
    log_2 = math.log(2.)
    Ep = log_2 - F.softplus(-p_samples)  # Note JSD will be shifted
    if average:
        return Ep.mean()
    else:
        return Ep

def get_negative_expectation(q_samples, average=True):
    #Measure = JSD (Simplified from DIM Code for clarity)
    log_2 = math.log(2.)
    Eq = F.softplus(-q_samples) + q_samples - log_2  # Note JSD will be shifted
    if average:
        return Eq.mean()
    else:
        return Eq

def loss_calc(lmap, gmap):
    #The fenchel_dual_loss from the DIM code
    # Reshape tensor dims to (N, channels, locations); N is taken from the input
    lmap = lmap.reshape(lmap.size(0), lmap.size(1), -1)
    gmap = gmap.squeeze()

    N, units, n_locals = lmap.size()
    n_multis = gmap.size(2)

    # First we make the input tensors the right shape. Each step builds on the
    # previous result (l / m), so the permute is not thrown away before flattening.
    l = lmap.view(N, units, n_locals)
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, units)

    m = gmap.view(N, units, n_multis)
    m = m.permute(0, 2, 1)
    m = m.reshape(-1, units)
    
    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
    
    mask = torch.eye(N).to(l.device)
    n_mask = 1 - mask
    
    E_pos = get_positive_expectation(u, average=False).mean(2).mean(2)
    E_neg = get_negative_expectation(u, average=False).mean(2).mean(2)
    
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos
    
    return loss

class Mixed_Dim(torch.nn.Module):
    def __init__(self):
        super(Mixed_Dim, self).__init__()
        #Local Feature Map: Local feature map of size [N, 128, 8, 8, 8]
        self.lmapnet = Sequential(Conv3d(1, 8, 2, 2), ReLU(), BatchNorm3d(8),
                                  Conv3d(8, 16, 2, 2), ReLU(), BatchNorm3d(16),
                                  Conv3d(16, 32, 2, 2), ReLU(), BatchNorm3d(32),
                                  Conv3d(32, 64, 2, 2), ReLU(), BatchNorm3d(64),
                                  Conv3d(64, 128, 2, 2), ReLU(), BatchNorm3d(128) 
                                 )
        
        #Global Feature Map: Global feature map of size [N, 128, 1, 1, 1]
        self.gmapnet = Sequential(Conv3d(128, 128, 2, 2), ReLU(), BatchNorm3d(128),
                                  Conv3d(128, 128, 2, 2), ReLU(), BatchNorm3d(128),
                                  Conv3d(128, 128, 2, 2), ReLU()
                                 )
        #Per paper, global map is activated:
        self.gfc1 = Sequential(Linear(1, 512), ReLU(), BatchNorm3d(128),
                               Linear(512, 512))
        self.gfc2 = Sequential(Linear(1, 512), ReLU(), BatchNorm3d(128))
        
        #Per paper, local map is activated:
        self.lfc1 = Sequential(Conv3d(128, 128, 1, 1), ReLU(), BatchNorm3d(128),
                               Conv3d(128, 128, 1, 1))
        self.lfc2 = Sequential(Conv3d(128, 128, 1, 1), ReLU(), BatchNorm3d(128))
        self.Laynorm = LayerNorm([128, 8, 8, 8])
        
    def forward(self, moving, atlas):
        #Local feature maps of moving and atlas
        lmap_moving = self.lmapnet(moving)
        lmap_atlas = self.lmapnet(atlas)
        
        #Global feature map of atlas
        gmap_atlas = self.gmapnet(lmap_atlas)
        
        #Encode Global feature map
        gmap_atlas_enc = self.gfc1(gmap_atlas) + self.gfc2(gmap_atlas)
        
        #Encode Local feature maps
        lmap_atlas_enc = self.Laynorm(self.lfc1(lmap_atlas) + self.lfc2(lmap_atlas))
        lmap_moving_enc = self.Laynorm(self.lfc1(lmap_moving) + self.lfc2(lmap_moving))
        
        return lmap_atlas_enc, lmap_moving_enc, gmap_atlas_enc
        
model_dim = Mixed_Dim().cuda().train()
optim = torch.optim.Adam(model_dim.parameters(), lr=0.01)
A = A.cuda()
B = B.cuda()
for epoch in range(100):
    loc_atlas, loc_moving, glob_atlas = model_dim(B, A)
    loss1 = loss_calc(loc_moving, glob_atlas)
    loss2 = loss_calc(loc_atlas, glob_atlas)
    loss = (loss2 + loss1)
    model_dim.zero_grad()
    print(loss.item())
    loss.backward()
    optim.step()    
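To make the shape bookkeeping in loss_calc concrete, here is a tiny sketch with made-up sizes (N, units, n_locals and n_multis are arbitrary assumptions). It checks that after the reshape and permute, u[i, j, a, b] is the dot product of global vector b of sample i with local vector a of sample j. That indexing only holds if each flattened row corresponds to one (sample, location) pair, which is why the feature maps have to be permuted to (N, locations, units) before reshape(-1, units).

```python
import torch

torch.manual_seed(0)
N, units, n_locals, n_multis = 3, 4, 5, 2   # arbitrary toy sizes

lmap = torch.randn(N, units, n_locals)      # (N, channels, locations)
gmap = torch.randn(N, units, n_multis)      # (N, channels, global vectors)

# Flatten so that each row is one (sample, location) / (sample, multi) vector
l = lmap.permute(0, 2, 1).reshape(-1, units)   # (N * n_locals, units)
m = gmap.permute(0, 2, 1).reshape(-1, units)   # (N * n_multis, units)

u = torch.mm(m, l.t())                                        # all global-local dot products
u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)   # (N, N, n_locals, n_multis)

i, j, a, b = 1, 2, 3, 0
direct = gmap[i, :, b] @ lmap[j, :, a]        # global b of sample i . local a of sample j
print(torch.allclose(u[i, j, a, b], direct))  # True

# mask[i, j] = 1 marks positive pairs (global and local from the same sample)
mask = torch.eye(N)
```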

OK, looking at your code, you’re doing DIM plus a cross-modal variant of DIM (atlas to moving?). Everything you’re doing code-wise seems to make sense, but I wouldn’t expect the cross-modal part to learn anything in this case, because A and B are drawn independently of each other.

-devon

Also be careful about the receptive field sizes of the local features: if they get too large, they aren’t “local” anymore, which almost surely makes things work poorly.
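For a plain stack of convolutions (no dilation), the receptive field follows a simple recurrence: each layer adds (kernel_size - 1) times the current spacing between outputs, and the spacing is multiplied by the stride. A minimal sketch; receptive_field is a hypothetical helper, applied here to the lmapnet posted above. Each of its 8 x 8 x 8 local features sees a 32 x 32 x 32 patch of the 256 x 256 x 256 input, and the 1 x 1 x 1 convolutions in lfc1/lfc2 do not enlarge it.

```python
def receptive_field(layers):
    """Receptive field (voxels per axis) of a stack of (kernel_size, stride) convolutions."""
    rf, jump = 1, 1                    # one input voxel; adjacent outputs 1 voxel apart
    for kernel, stride in layers:
        rf += (kernel - 1) * jump      # each layer widens the field by (kernel - 1) jumps
        jump *= stride                 # spacing between neighbouring outputs grows by stride
    return rf

# lmapnet: five Conv3d(kernel_size=2, stride=2) layers (ReLU/BatchNorm don't change it)
print(receptive_field([(2, 2)] * 5))  # 32
```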

Appreciate the feedback @erroneus! I have a question about your comment that the cross-modal part won’t work because A and B are drawn independently. I’m working with medical imaging: A and B are different modalities but come from the same patient, so the distributions shouldn’t be totally independent despite the difference in imaging modality (see picture).

I borrowed part of your diagram from appendix A.2, Figure 6, for the sake of clarity. No plagiarism intended; happy to repost without it if you’d prefer.

Or are you referring to my bit of code:

loss1 = loss_calc(loc_moving, glob_atlas)
loss2 = loss_calc(loc_atlas, glob_atlas)
loss = (loss2 + loss1)

I’m not supposed to concatenate the local feature maps together, right? i.e.

loc_maps = torch.cat((loc_atlas, loc_moving))
loss = loss_calc(loc_maps, glob_atlas) 

I will certainly do some tuning to optimize receptive field size as well. Thanks again!

-Drew

I wouldn’t concat. I would use two different discriminators for those two losses, each of which uses a dot-product-based scoring function.
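A minimal sketch of that suggestion, assuming PyTorch and hypothetical names (DotProductCritic, jsd_mi_loss) that are not from the DIM repo. Each loss gets its own critic with separate local and global projection heads, and pairs are scored by a dot product. One critic scores atlas-local against atlas-global features, the other scores moving-local against atlas-global features, and the two JSD-style losses are summed instead of concatenating the local maps.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn

class DotProductCritic(nn.Module):
    """Hypothetical critic: separate projections for local and global features,
    pairs scored by a dot product."""

    def __init__(self, local_ch, global_dim, units=64):
        super().__init__()
        self.local_proj = nn.Conv3d(local_ch, units, kernel_size=1)  # per-location projection
        self.global_proj = nn.Linear(global_dim, units)

    def forward(self, lmap, gvec):
        # lmap: (N, C, D, H, W) local features, gvec: (N, G) global features
        l = self.local_proj(lmap).flatten(2)       # (N, units, n_locals)
        g = self.global_proj(gvec)                 # (N, units)
        return torch.einsum("iu,juk->ijk", g, l)   # scores[i, j, k] = g_i . l_j[k]

def jsd_mi_loss(scores):
    # Positive pairs: same sample (i == j) at every location; everything else is negative
    n = scores.size(0)
    pos_mask = torch.eye(n, dtype=torch.bool, device=scores.device).unsqueeze(-1).expand_as(scores)
    e_pos = (math.log(2.0) - F.softplus(-scores[pos_mask])).mean()
    e_neg = (F.softplus(scores[~pos_mask]) - math.log(2.0)).mean()
    return e_neg - e_pos

torch.manual_seed(0)
lmap_atlas, lmap_moving = torch.randn(4, 128, 8, 8, 8), torch.randn(4, 128, 8, 8, 8)
g_atlas = torch.randn(4, 128)

# One critic per loss, rather than one critic on concatenated local maps
critic_atlas = DotProductCritic(128, 128)
critic_moving = DotProductCritic(128, 128)
scores_atlas = critic_atlas(lmap_atlas, g_atlas)
loss = jsd_mi_loss(scores_atlas) + jsd_mi_loss(critic_moving(lmap_moving, g_atlas))
loss.backward()
```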

Hi @tom @thompa2
I have the same problem with histc and backward().
I want to compute L1, L2, and Sinkhorn losses based on the histogram and call backward(), but it is not possible.

What did you end up doing to use a histogram in your loss?

I would appreciate your help.
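One common workaround, not proposed by anyone in this thread, is to replace the hard bin assignment with a soft one (Gaussian kernel, or Parzen-window, binning), so that each value contributes smoothly to nearby bins and gradients can flow. Below is a minimal sketch under that assumption. soft_bin_weights, soft_histogram and soft_mi are hypothetical helpers, not PyTorch APIs, and sigma trades gradient smoothness against histogram resolution.

```python
import torch

def soft_bin_weights(x, bins=10, vmin=0.0, vmax=1.0, sigma=None):
    """(num_samples, bins) Gaussian soft assignment of each value to each bin centre."""
    centers = torch.linspace(vmin, vmax, bins, device=x.device, dtype=x.dtype)
    if sigma is None:
        sigma = (vmax - vmin) / (bins - 1)       # kernel width = bin spacing (tunable)
    diff = x.reshape(-1, 1) - centers.reshape(1, -1)
    w = torch.exp(-0.5 * (diff / sigma) ** 2)
    return w / w.sum(dim=1, keepdim=True)       # each sample's weights sum to 1

def soft_histogram(x, **kw):
    return soft_bin_weights(x, **kw).sum(dim=0)  # soft counts per bin

def soft_mi(a, b, eps=1e-10, **kw):
    wa, wb = soft_bin_weights(a, **kw), soft_bin_weights(b, **kw)
    p_ab = wa.t() @ wb / wa.size(0)              # soft joint histogram, sums to 1
    p_a = p_ab.sum(dim=1, keepdim=True)
    p_b = p_ab.sum(dim=0, keepdim=True)
    return (p_ab * torch.log((p_ab + eps) / (p_a * p_b + eps))).sum()

torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True)   # e.g. network output scaled to [0, 1]
target = torch.rand(1000)                  # e.g. reference image intensities

p = soft_histogram(x)
q = soft_histogram(target)
p, q = p / p.sum(), q / q.sum()
hist_loss = (p - q).abs().sum() + ((p - q) ** 2).sum()   # L1 + L2 on the histograms
mi_loss = -soft_mi(x, target)                            # maximise (soft) MI
(hist_loss + mi_loss).backward()
print(x.grad.abs().sum() > 0)   # gradients now reach x
```

The normalised p and q could be fed to an optimal-transport (Sinkhorn) loss in the same way; that part is not shown. Inputs should lie roughly within [vmin, vmax], otherwise all kernel weights for a sample can underflow to zero.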