L1 loss / reconstruction loss turns 0

I am using a simple L1 loss on a single channel for 3D volumetric segmentation. The one-hot encoded target has size:
2, 9, 64, 96, 96
I am using the following code for this:
(The identify_axis function returns [2, 3, 4] for 3D images.)

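import torch

def channelwiseL1(pred, trgt):
    # identify_axis returns the spatial dims, [2, 3, 4] for 3D volumes
    axis = identify_axis(pred.shape)
    # Average over the spatial dims, then over the batch: one value per channel
    pred = torch.mean(torch.mean(pred, axis), 0)
    trgt = torch.mean(torch.mean(trgt, axis), 0)
    # Earlier attempt, left commented out:
    # return torch.mean(torch.mean(torch.abs(pred, trgt), axis), 0)
    # Absolute difference of the channel means, for channel 7 only
    return abs(pred - trgt)[7]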

I wanted to implement the loss only on channel 7.
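
To make concrete what the function above computes, here is a minimal sketch that compares it with a voxel-wise L1 on channel 7. It assumes the spatial axes are (2, 3, 4), hard-coded in place of identify_axis, and uses small toy volumes instead of my real data. The names mean_difference_l1 and voxelwise_l1_channel are hypothetical and only for illustration. Because the means are taken before the absolute difference, the first variant only compares how much of the volume is assigned to the channel, not where, which may be why it can reach 0.

```python
import torch

SPATIAL_AXES = (2, 3, 4)  # assumption: what identify_axis returns for 3D input

def mean_difference_l1(pred, trgt, channel=7):
    # Same computation as channelwiseL1: reduce to one mean per channel first,
    # then take the absolute difference of those means.
    pred_mean = pred.mean(dim=SPATIAL_AXES).mean(dim=0)
    trgt_mean = trgt.mean(dim=SPATIAL_AXES).mean(dim=0)
    return (pred_mean - trgt_mean).abs()[channel]

def voxelwise_l1_channel(pred, trgt, channel=7):
    # Hypothetical alternative: absolute difference per voxel, then the mean.
    return (pred[:, channel] - trgt[:, channel]).abs().mean()

# Toy volumes with the same layout as my target, but a smaller spatial size.
trgt = torch.zeros(2, 9, 4, 4, 4)
pred = torch.zeros(2, 9, 4, 4, 4)
trgt[:, 7, :2] = 1.0  # target fills the first half along the first spatial axis
pred[:, 7, 2:] = 1.0  # prediction fills the other half, so no voxel matches

loss_means = mean_difference_l1(pred, trgt)
loss_voxels = voxelwise_l1_channel(pred, trgt)
print(loss_means.item(), loss_voxels.item())
```

In this toy case the prediction and target do not overlap at all, yet the difference of the means is 0 while the voxel-wise version is 1.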
Unfortunately, the L1 loss drops to 0 after 3 or 4 epochs.
I am using the following optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))