Consider this example:

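```
import torch
import torch.nn as nn

# img  : N x 3 x H x W
# label: N x H x W (class indices)
x = generator(img)              # x is 10x21x321x321
soft_pred = nn.Softmax2d()(x)   # per-pixel class probabilities
y = discriminator(x)            # y is 10x2x321x321
y_1 = y[:, 1, :, :]             # channel 1 of the discriminator output
loss = 0
for i in range(10):
    for j in range(321):
        for k in range(321):
            # only pixels whose discriminator score exceeds 0.3 contribute
            if y_1[i][j][k] > 0.3:
                c = label[i][j][k]
                loss += -torch.log(soft_pred[i][c][j][k])
```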

Is there a way I can use the standard loss function (NLLLoss2d) here?
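For reference, one possible way to express the loop with a standard loss is sketched here. It assumes the threshold test can be turned into a mask, and that the masked-out pixels can be marked with the `ignore_index` value that `F.nll_loss` skips. The helper name `masked_nll` and its parameters are hypothetical, not part of the code above.

```python
import torch
import torch.nn.functional as F

def masked_nll(logits, label, score, threshold=0.3, ignore_index=-100):
    """Sum of -log softmax(logits)[label] over pixels where score > threshold.

    logits: N x C x H x W raw generator output (before softmax)
    label:  N x H x W integer class indices (dtype long)
    score:  N x H x W discriminator scores used for masking
    """
    # log_softmax over the channel dim equals log(Softmax2d(logits)),
    # but is numerically more stable than taking the log afterwards.
    log_prob = F.log_softmax(logits, dim=1)
    # Copy the labels and mark every pixel that fails the test as ignored.
    target = label.clone()
    target[~(score > threshold)] = ignore_index
    # nll_loss accepts N x C x H x W input with an N x H x W target.
    return F.nll_loss(log_prob, target,
                      ignore_index=ignore_index, reduction="sum")
```

With the names from the question this would be called as `loss = masked_nll(x, label, y_1)`. As in the loop, the mask only selects pixels, so no gradient flows through the comparison with `y_1`. In recent PyTorch versions `nn.NLLLoss` handles the 2D case directly, so `NLLLoss2d` should not be needed.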