I am using a simple L1 loss on a single channel for 3D volumetric segmentation. The target (one-hot encoded) has size:
2, 9, 64, 96, 96
I am using the following code for that (the identify_axis function returns [2, 3, 4] for 3D images):
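import torch

def channelwiseL1(pred, trgt):
    # Spatial axes of the (N, C, D, H, W) tensors, i.e. [2, 3, 4].
    axis = identify_axis(pred.shape)
    # Average over the spatial axes, then over the batch: one value per channel.
    pred = torch.mean(torch.mean(pred, axis), 0)
    trgt = torch.mean(torch.mean(trgt, axis), 0)
    # return torch.mean(torch.mean(torch.abs(pred - trgt), axis), 0)
    # Absolute difference of the channel means, for channel 7 only.
    return abs(pred - trgt)[7]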
I wanted to implement the loss only on channel 7.
Unfortunately, the L1 loss becomes 9 after 3 or 4 epochs.
I am using the following optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
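To make clear what channelwiseL1 computes: it takes the absolute difference of the averaged tensors, not the average of the per-voxel absolute differences (which is what the commented-out line seems to aim at). Below is a minimal sketch comparing the two on a toy tensor. The identify_axis body here is a hypothetical stand-in, since the real helper is not shown, and the function names channelwise_l1_of_means and channelwise_l1_per_voxel are made up for illustration.

```python
import torch

def identify_axis(shape):
    # Hypothetical stand-in for the real helper: spatial axes of an
    # (N, C, D, H, W) or (N, C, H, W) tensor.
    if len(shape) == 5:
        return [2, 3, 4]
    if len(shape) == 4:
        return [2, 3]
    raise ValueError("expected a 4D or 5D tensor shape")

def channelwise_l1_of_means(pred, trgt, channel=7):
    # Same computation as channelwiseL1: average first, then |difference|.
    axis = identify_axis(pred.shape)
    pred_mean = pred.mean(dim=axis).mean(dim=0)
    trgt_mean = trgt.mean(dim=axis).mean(dim=0)
    return (pred_mean - trgt_mean).abs()[channel]

def channelwise_l1_per_voxel(pred, trgt, channel=7):
    # Per-voxel variant: |difference| first, then average over everything.
    return (pred[:, channel] - trgt[:, channel]).abs().mean()

# Toy example: channel 7 is predicted in the wrong sample of the batch.
pred = torch.zeros(2, 9, 4, 6, 6)
trgt = torch.zeros(2, 9, 4, 6, 6)
pred[0, 7] = 1.0
trgt[1, 7] = 1.0

loss_of_means = channelwise_l1_of_means(pred, trgt)
loss_per_voxel = channelwise_l1_per_voxel(pred, trgt)
print(loss_of_means.item(), loss_per_voxel.item())
```

In this toy case the first variant reports zero loss even though every voxel of channel 7 is wrong, because the errors cancel once the tensors are averaged. Whether this is related to the behaviour I am seeing is only a guess.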