Create pixel-wise vectors from multiple feature maps

Hello.
I’m trying to implement the U-Net paper from scratch, but I’m struggling to implement what they call the “energy function”.
Citing the paper:

So far this is what I have:

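import torch.nn as nn
import torch.nn.functional as F

class EnergyFunction(nn.Module):
    def __init__(self):
        super(EnergyFunction, self).__init__()

    def forward(self, logits):
        # logits is of shape [batch_size, feature_channels, height, width]
        # Here: [4, 2, 388, 388]
        # dim=1 is the channel dimension; passing it explicitly avoids the
        # deprecated implicit-dim behaviour of F.softmax
        r = F.softmax(logits, dim=1)
        # r is of shape: [4, 2, 388, 388]
        return r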

But according to the paper, I should pass the pixels of my two 388x388 matrices (A and B) to a softmax function, where each “pixel” is a vector: {[A(0,0), B(0,0)], [A(0,1), B(0,1)], …} and so on. So basically I should end up with something like {softmax([A(0,0), B(0,0)]), softmax([A(0,1), B(0,1)]), …}, right?
And the resulting shape of r should be [4, 1, 388, 388] right?
I really have no idea how to do that.
Thank you for your help.
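The per-pixel softmax described in the question can be checked directly on a small tensor. This is a minimal sketch, assuming a toy batch of random logits with two channels (A and B) in place of the real 4 x 2 x 388 x 388 network output: `F.softmax(..., dim=1)` applies a softmax to each vector [A(i,j), B(i,j)] separately, so the result keeps the shape [batch, 2, H, W] rather than collapsing to a single channel.

```python
import torch
import torch.nn.functional as F

# Toy stand-in for the network output: batch of 4, two channels (A and B), 5x5 pixels.
torch.manual_seed(0)
logits = torch.randn(4, 2, 5, 5)

# Softmax over the channel dimension, i.e. independently at every pixel.
r = F.softmax(logits, dim=1)
print(r.shape)  # torch.Size([4, 2, 5, 5])

# The same thing done by hand for pixel (0, 1) of the first image.
A, B = logits[0, 0], logits[0, 1]
manual = F.softmax(torch.stack([A[0, 1], B[0, 1]]), dim=0)
print(torch.allclose(r[0, :, 0, 1], manual))  # True

# At every pixel the two probabilities sum to 1.
print(torch.allclose(r.sum(dim=1), torch.ones(4, 5, 5)))  # True

# A single-channel map [4, 1, H, W] only appears if you select one channel,
# e.g. the probability of class 1 at each pixel:
p1 = r[:, 1:2]
print(p1.shape)  # torch.Size([4, 1, 5, 5])
```

So the softmax itself does not reduce the channel count; a [4, 1, 388, 388] map only comes from picking one channel (or the true class) afterwards.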

This is what they mean:

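import torch
import torch.nn as nn
import torch.nn.functional as F

# output: 4 x 2 x 388 x 388 raw outputs (logits) of the last layer (random placeholder here).
output = torch.randn(4, 2, 388, 388)

log_probs = F.log_softmax(output, dim=1)

# With dim=1, for each pixel location the softmax is applied over the channels.
# (For 4D input, older PyTorch versions picked dim=1 implicitly; that is deprecated.)

m = nn.NLLLoss()  # nn.NLLLoss2d is deprecated; nn.NLLLoss accepts 4D input directly

# target holds the class index of each pixel: a LongTensor of shape 4 x 388 x 388
# (no channel dimension). Random placeholder here.
target = torch.randint(0, 2, (4, 388, 388))
loss = m(log_probs, target)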
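For reference, these are the two formulas from the U-Net paper (Ronneberger et al., 2015) that the code maps to, where a_k(x) is the activation of channel k at pixel x, K is the number of classes, ℓ(x) is the true label of pixel x and w(x) is a weight map. The energy is written here with a leading minus sign so that it is a quantity to minimize.

$$
p_k(\mathbf{x}) = \frac{\exp\big(a_k(\mathbf{x})\big)}{\sum_{k'=1}^{K} \exp\big(a_{k'}(\mathbf{x})\big)}
$$

$$
E = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \, \log p_{\ell(\mathbf{x})}(\mathbf{x})
$$

`F.log_softmax(..., dim=1)` produces log p_k(x) for every channel k and pixel x. `nn.NLLLoss` (or the older `nn.NLLLoss2d`) then uses the target as ℓ(x) to pick out -log p_ℓ(x)(x) at each pixel and, by default, averages over all pixels and images, which matches E with w(x) = 1 up to a constant factor. That is why the targets are needed: ℓ(x) appears in E itself.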

Thank you, but I don’t quite get it. Why would we use log_softmax instead of softmax? And why do we need the targets? According to the paper, we only need the output feature channels of our network, right? Only later do we apply cross-entropy to the logits and the target, to penalize at each position the deviation of the true-class probability from 1.
I really want to replicate the paper with 100% authentic code, without using techniques discovered after the paper was released.
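To see why `log_softmax` is used and where the target enters, the sketch below computes E literally from the formula (softmax, pick the true-class probability at each pixel, take the log, weight and sum) and compares it with `nn.NLLLoss` applied to `log_softmax` output. It assumes small random tensors and a uniform weight map `w` (a hypothetical name standing in for the paper's w(x)). `log_softmax` gives the same numbers as `log(softmax(...))`, but it stays finite when one logit is much larger than the others, where the two-step version rounds a probability to 0 and returns log(0) = -inf.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, K, H, W = 4, 2, 6, 6
logits = torch.randn(N, K, H, W)          # a_k(x): raw network outputs
target = torch.randint(0, K, (N, H, W))   # l(x): true class index of each pixel
w = torch.ones(N, H, W)                   # hypothetical weight map w(x), uniform here

# Literal translation of the formulas: p_k(x), then p_{l(x)}(x) at every pixel.
p = F.softmax(logits, dim=1)                            # [N, K, H, W]
p_true = p.gather(1, target.unsqueeze(1)).squeeze(1)    # [N, H, W]
E_manual = -(w * torch.log(p_true)).sum()

# Same quantity via log_softmax + NLLLoss (summed instead of averaged).
E_nll = nn.NLLLoss(reduction="sum")(F.log_softmax(logits, dim=1), target)
print(torch.allclose(E_manual, E_nll))  # True

# Why log_softmax: with a large logit, softmax rounds the other class to 0,
# so log(softmax) contains -inf, while log_softmax stays finite (about -1000).
big = torch.tensor([[1000.0, 0.0]])
print(torch.log(F.softmax(big, dim=1)))
print(F.log_softmax(big, dim=1))
```

In other words, log_softmax + NLLLoss is the paper's softmax followed by cross-entropy, just computed in a numerically safer order.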
Following your instructions I now have:

import torch.nn as nn
import torch.nn.functional as F

class EnergyFunction(nn.Module):
    def __init__(self):
        super(EnergyFunction, self).__init__()
        # nn.NLLLoss replaces the deprecated nn.NLLLoss2d and accepts 4D input
        self.nll = nn.NLLLoss()

    def forward(self, logits, target):
        # logits is of shape [batch_size, feature_channels, height, width]
        log_probs = F.log_softmax(logits, dim=1)

        # At this point:
        # log_probs.size() == [4, 2, 388, 388]
        # target.size() == [4, 388, 388]
        loss = self.nll(log_probs, target.long())
        return loss

But I get a single loss value as a result, which is not really what I want. My goal here is really to map the maths formulas to the code. Thank you.
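The single value comes from the default reduction of `nn.NLLLoss`, which averages over every pixel of every image. A minimal sketch, assuming toy tensor sizes and a hypothetical weight map `w`: with `reduction="none"` the loss keeps one entry per pixel, shape [batch, H, W], i.e. the term -log p_ℓ(x)(x) for each x, and the paper's weighted sum E becomes an explicit multiply-and-sum.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 2, 8, 8)          # toy stand-in for [4, 2, 388, 388]
target = torch.randint(0, 2, (4, 8, 8))   # true class l(x) per pixel

log_probs = F.log_softmax(logits, dim=1)

# One loss value per pixel: -log p_{l(x)}(x), shape [4, 8, 8].
pixel_loss = nn.NLLLoss(reduction="none")(log_probs, target)
print(pixel_loss.shape)  # torch.Size([4, 8, 8])

# Hypothetical per-pixel weight map w(x); random here, for illustration only.
w = torch.rand(4, 8, 8)
E = (w * pixel_loss).sum()  # weighted energy, summed over all pixels and images

# The default reduction is just the plain mean of that per-pixel map.
mean_loss = nn.NLLLoss()(log_probs, target)
print(torch.allclose(mean_loss, pixel_loss.mean()))  # True
```

Note that the `weight` argument of `nn.NLLLoss` is per class (one value per channel), not per pixel, so a per-pixel w(x) has to be applied by hand to the unreduced loss as above.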