About KLDivLoss

Hi all,
I have two outputs of conv2d, out_a and out_b, each of shape (N, C, H, W), and I want to calculate the KL divergence loss between these two feature maps.
I am not quite sure how KLDivLoss computes its result. I apply softmax to out_a and out_b over the channel dimension (dim=1), so the values at each spatial position sum to 1 across channels. Is my code correct?
Thanks

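```python
import torch.nn as nn
import torch.nn.functional as F

# input: log-probabilities, target: probabilities, both over channels (dim=1)
sfm_out_a = F.log_softmax(out_a, dim=1)
sfm_out_b = F.softmax(out_b, dim=1)
# call the module rather than .forward() so registered hooks still run
kl_loss = nn.KLDivLoss()(sfm_out_a, sfm_out_b)
```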
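One way to check what the loss actually computes is to compare it against a manual per-pixel KL divergence. The following is a minimal sketch that uses random tensors as hypothetical stand-ins for the two conv2d outputs. With `reduction='none'`, `nn.KLDivLoss` returns the elementwise terms `p_b * (log p_b - log p_a)`. Summing those over `dim=1` gives KL(p_b || p_a) at each pixel. Note the direction: the target is the first argument of the KL divergence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the two conv2d outputs, shape (N, C, H, W).
N, C, H, W = 2, 4, 3, 3
out_a = torch.randn(N, C, H, W)
out_b = torch.randn(N, C, H, W)

# KLDivLoss expects log-probabilities as input and probabilities as target.
log_p_a = F.log_softmax(out_a, dim=1)
p_b = F.softmax(out_b, dim=1)

# Elementwise terms p_b * (log p_b - log p_a); summing over the channel
# dimension gives KL(p_b || p_a) for every pixel, shape (N, H, W).
pointwise = nn.KLDivLoss(reduction='none')(log_p_a, p_b)
kl_per_pixel = pointwise.sum(dim=1)

# Manual reference computation of the same per-pixel divergence.
manual = (p_b * (torch.log(p_b) - log_p_a)).sum(dim=1)

# Default reduction='mean' averages over all N*C*H*W elements
# (PyTorch emits a warning about this), so it equals the mean
# per-pixel KL divided by C.
default_loss = nn.KLDivLoss()(log_p_a, p_b)
mean_pixel_kl = kl_per_pixel.mean()

# reduction='batchmean' sums every term and divides by N only.
batchmean_loss = nn.KLDivLoss(reduction='batchmean')(log_p_a, p_b)
```

The original code runs, but the default `reduction='mean'` divides by every element, including the C channels. The result is the mean per-pixel KL scaled by 1/C, not the KL divergence itself. If the goal is the average KL per pixel, `reduction='none'` followed by `.sum(dim=1).mean()` gives it directly. `reduction='batchmean'` instead sums over C, H and W and averages only over the batch.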