Bhattacharyya loss

Hello,

I would like to use the Bhattacharyya distance between two saliency maps as a loss function for my network. But if I want to use the autograd engine to back-propagate my loss through the network, I need to keep my output as a Variable, which doesn't seem to be possible since I need to do an element-wise multiplication.

Any idea ?

Yours Justin

What about using torch.addcmul?

Fair point, but there is a square root :confused:

torch.sqrt

I used the following:

def bhatta_loss(output, target, eps=1e-12):
    """Return ``-log(sum(sqrt(|output * target|)))``, a Bhattacharyya-style distance.

    The sum runs over every element of the tensors, batch dimension included.
    For this to be the true Bhattacharyya distance, ``output`` and ``target``
    should each be normalized to sum to 1; that is not enforced here.

    Args:
        output: prediction tensor; gradients flow back through it.
        target: tensor broadcastable against ``output``.
        eps: lower bound applied to the Bhattacharyya coefficient before the
            log, so an all-zero product gives a finite loss instead of +inf.

    Returns:
        A 0-dim tensor holding the loss.
    """
    product = torch.abs(torch.mul(output, target))
    nonzero = product != 0
    # d/dx sqrt(x) is infinite at x == 0, and abs() contributes a 0 derivative
    # there, so the naive sqrt(abs(.)) backprops inf * 0 = NaN into the network.
    # Feed sqrt a harmless 1 at the zero entries and then discard that branch.
    # The forward value is unchanged and those entries get a finite (zero)
    # gradient. NaN entries are != 0, so they still propagate visibly.
    safe_product = torch.where(nonzero, product, torch.ones_like(product))
    roots = torch.where(nonzero, torch.sqrt(safe_product), torch.zeros_like(product))
    coefficient = torch.sum(roots)
    return -torch.log(torch.clamp(coefficient, min=eps))

Backward works, but I get NaN after the first iteration :confused:
Any ideas ?

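The only way the loss could be NaN would be if torch.mul(output, target) consisted entirely of zeros. Maybe your model is rapidly learning to produce zero output. (Note, however, that a single zero entry in torch.mul(output, target) is enough to make the gradients NaN. The derivative of torch.sqrt at 0 is infinite, and backpropagating it through torch.abs, whose derivative at 0 is 0, gives inf * 0 = NaN. A NaN gradient then turns the weights, and hence the next output, into NaN.)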

I meant that the output of the network becomes NaN on the next forward pass after backpropagating.

Weird.
Without seeing any code I have no idea what might be happening.

Oops, sorry. Here is my network:

class VGGNet(nn.Module):
    """VGG16 encoder plus three transposed-conv decoders fused into one saliency map.

    ``self.features`` is laid out like torchvision's ``vgg16().features``
    (conv/ReLU/pool in the same order), which ``_initialize_weights`` relies on
    when it copies pretrained conv weights index by index. The activations at
    indices ``self.select`` are the ReLUs after conv3_3, conv4_3 and conv5_3.
    Each one is decoded to a one-channel map, cropped to a 224x224 window,
    tiled to 3 channels and passed through a sigmoid. The resulting 9 channels
    are fused by ``final_attention_pred`` and a final sigmoid.
    """

    def __init__(self):
        """Build the encoder and decoders and initialize their weights."""
        super(VGGNet, self).__init__()
        # Indices into self.features of the ReLUs after conv3_3, conv4_3, conv5_3.
        self.select = [15, 22, 29]
        self.features = torch.nn.Sequential(
            # conv1
            # NOTE(review): padding=35 (vs 1 in stock VGG) presumably enlarges the
            # maps so the decoder outputs can be cropped back to 224x224 — confirm.
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=35),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv2
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv3
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv4
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
            # conv5
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, stride=2),
        )
        # Decodes the 256-channel conv3_3 activation (features index 15).
        self.deconv1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 128, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 64, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, padding=0, stride=1),
            torch.nn.ReLU(),
        )
        # Decodes the 512-channel conv4_3 activation (features index 22).
        self.deconv2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(512, 256, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(256, 128, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 64, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, padding=0, stride=1),
            torch.nn.ReLU(),
        )
        # Decodes the 512-channel conv5_3 activation (features index 29).
        self.deconv3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(512, 512, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(512, 256, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(256, 128, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 64, 4, stride=2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, padding=0, stride=1),
            torch.nn.ReLU(),
        )
        # Fuses the 3 decoders x 3 tiled channels = 9 channels into one map.
        self.final_attention_pred = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(9, 1, 3, stride=1, padding=1)
        )
        self._grad_hooks_registered = False
        self._initialize_weights()

    def _initialize_weights(self):
        """Copy pretrained VGG16 conv weights into the encoder; init decoders from N(0, 0.01)."""
        # Index i refers to the same layer in both Sequentials (identical layout).
        for i, layer in enumerate(models.vgg16(pretrained=True).features):
            if isinstance(layer, torch.nn.Conv2d):
                self.features[i].weight.data = layer.weight.data
                self.features[i].bias.data = layer.bias.data
        # Each decoder is initialized exactly once.
        for decoder in (self.deconv1, self.deconv2, self.deconv3, self.final_attention_pred):
            for m in decoder:
                if isinstance(m, torch.nn.ConvTranspose2d):
                    m.weight.data.normal_(0.0, 0.01)
                    m.bias.data.zero_()

    def _register_grad_hooks(self):
        """Attach ``printgradnorm`` as a backward hook on every encoder/decoder layer."""
        # printgradnorm is a debugging helper defined outside this class.
        for layer in self.features:
            layer.register_backward_hook(printgradnorm)
        for decoder in (self.deconv1, self.deconv2, self.deconv3):
            for layer in decoder:
                layer.register_backward_hook(printgradnorm)
        self.final_attention_pred[0].register_backward_hook(printgradnorm)

    @staticmethod
    def _crop_and_tile(decoder, feature, offset):
        """Decode ``feature``, crop a 224x224 window at ``offset`` and repeat its channel 3 times."""
        attention = decoder(feature)[:, :, offset:offset + 224, offset:offset + 224]
        return torch.cat([attention, attention, attention], 1)

    def forward(self, x):
        """Return the fused saliency map (after a final sigmoid) for input batch ``x``."""
        # Hooks persist on the modules, so register them only once. Registering
        # them on every call would stack one more hook per layer each iteration:
        # the prints multiply and the hook dicts grow without bound.
        if not getattr(self, "_grad_hooks_registered", False):
            self._register_grad_hooks()
            self._grad_hooks_registered = True

        # Run the encoder, keeping the activations at the selected indices.
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.select:
                features.append(x)

        # This sigmoid is a fresh module on every call, so hooking it here does
        # not accumulate.
        m = nn.Sigmoid()
        m.register_backward_hook(printgradnorm)

        attentionmap1 = self._crop_and_tile(self.deconv1, features[0], 36)
        attentionmap2 = self._crop_and_tile(self.deconv2, features[1], 42)
        attentionmap3 = self._crop_and_tile(self.deconv3, features[2], 54)

        saliency = [
            m(attentionmap1),
            torch.sigmoid(attentionmap2),
            torch.sigmoid(attentionmap3),
        ]
        output_data = torch.cat(saliency, 1)
        return torch.sigmoid(self.final_attention_pred(output_data))

The loss function is the one posted above.

I have to say I am not that experienced with image processing using neural networks. You might be better off starting a new thread to ask why the network might output NaN.

An alternative would be to stick a load of print statements in the forward method in order to try and figure out at what point in the calculation the NaNs start appearing.

I'm encountering the same issue with the same metric. How did you resolve it?