Why doesn't my ResU-Net train?

I adopted a network called ResU-Net and changed very little of its code; the main difference is that I used our own image dataset. I tried different learning rates and different cropping, but I could not get the network to train normally. The CE/Dice loss barely changed, which I did not expect, and the training Dice and validation loss were also very poor, as shown below:
|Epoch|Train_Loss|Train_dice|Val_Loss|Val_dice|
|---|---|---|---|---|
|1|1.8275|0.405|0.504|0.4032|
|2|1.8266|0.4067|0.504|0.4032|
|3|1.8266|0.4067|0.504|0.4032|
|4|1.8266|0.4067|0.504|0.4032|
|5|1.8266|0.4067|0.504|0.4032|
|6|1.8266|0.4067|0.504|0.4032|
|7|1.8266|0.4067|0.504|0.4032|
|8|1.8266|0.4067|0.504|0.4032|
|9|1.8266|0.4067|0.504|0.4032|
|10|1.8266|0.4067|0.504|0.4032|
|…|…|…|…|…|
|100|1.3404|0.4067|0.504|0.4032|
|101|1.3404|0.4067|0.504|0.4032|
|102|1.3404|0.4067|0.504|0.4032|
|103|1.3404|0.4067|0.504|0.4032|
|104|1.3404|0.4067|0.504|0.4032|
|105|1.3404|0.4067|0.504|0.4032|
|106|1.3404|0.4067|0.504|0.4032|
|107|1.3404|0.4067|0.504|0.4032|
|108|1.3404|0.4067|0.504|0.4032|
|109|1.3404|0.4067|0.504|0.4032|
|110|1.3404|0.4067|0.504|0.4032|
|111|1.3404|0.4067|0.504|0.4032|
|112|1.3404|0.4067|0.504|0.4032|
|…|…|…|…|…|
|287|0.964|0.4067|0.504|0.4032|
|288|0.964|0.4067|0.504|0.4032|
|289|0.964|0.4067|0.504|0.4032|
|290|0.964|0.4067|0.504|0.4032|
|291|0.964|0.4067|0.504|0.4032|
|292|0.964|0.4067|0.504|0.4032|
|293|0.964|0.4067|0.504|0.4032|
|294|0.964|0.4067|0.504|0.4032|
|295|0.964|0.4067|0.504|0.4032|
|296|0.964|0.4067|0.504|0.4032|
|297|0.964|0.4067|0.504|0.4032|
|298|0.964|0.4067|0.504|0.4032|
|299|0.964|0.4067|0.504|0.4032|
|300|0.964|0.4067|0.504|0.4032|
|301|0.9373|0.4067|0.504|0.4032|
|302|0.9373|0.4067|0.504|0.4032|
|303|0.9373|0.4067|0.504|0.4032|

What's wrong with this ResU-Net? Can anyone give some advice? Thank you so much.

There could be dozens of reasons why you’re seeing this behavior, including but not limited to how you are calculating your statistics.

It seems your training loss decreases in discrete steps every so many epochs, while the Dice scores and the validation loss do not move at all.

Please share your code, including how you calculate your metrics. Be sure to wrap your code in triple backticks so it is formatted as a code block.

Here is the ResU-Net code I used. Could you help check it? I'm grateful for your answer:

train:
```python
import torch
import torch.nn.functional as F
from collections import OrderedDict

# `metrics` and `common` are helper modules from the original repo (not shown here).
# `epoch` and `device` are module-level globals in the original training script.


def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    print("=======Epoch:{}=======lr:{}".format(epoch, optimizer.param_groups[0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for idx, (data, target) in enumerate(train_loader):
        torch.cuda.empty_cache()  # frees cached GPU memory every batch; can slow training
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target, n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # four softmax maps in training mode (deep supervision)
        # Each head: region (Tversky) loss + binary cross-entropy on the softmax probabilities
        loss0 = loss_func(output[0], target) + F.binary_cross_entropy(output[0], target)
        loss1 = loss_func(output[1], target) + F.binary_cross_entropy(output[1], target)
        loss2 = loss_func(output[2], target) + F.binary_cross_entropy(output[2], target)
        loss3 = loss_func(output[3], target) + F.binary_cross_entropy(output[3], target)

        # Final full-resolution head at full weight, auxiliary heads scaled by alpha
        loss = loss3 + alpha * (loss0 + loss1 + loss2)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), data.size(0))
        train_dice.update(output[3], target)

    val_log = OrderedDict({'Train_Loss': train_loss.avg, 'Train_dice_liver': train_dice.avg[1]})
    if n_labels == 3:
        val_log.update({'Train_dice_tumor': train_dice.avg[2]})
    return val_log
```
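The four-line loss in `train()` can also be written as a loop over the deep-supervision heads, which makes the weighting easier to see and to change. A minimal sketch, assuming `outputs` is the 4-tuple of softmax maps that ResUNet returns in training mode and `target` is the one-hot volume; the helper name `deep_supervision_loss` and the toy `soft_dice_loss` are hypothetical:

```python
import torch
import torch.nn.functional as F


def deep_supervision_loss(outputs, target, region_loss, alpha):
    """Combine the per-head losses the way train() does.

    outputs:     sequence of softmax maps (B, C, D, H, W); the last one is the full-resolution head
    target:      one-hot float tensor with the same shape
    region_loss: callable(pred, target) -> scalar, e.g. a Tversky or Dice loss
    """
    # Each head: region loss + binary cross-entropy on the probabilities
    per_head = [region_loss(o, target) + F.binary_cross_entropy(o, target) for o in outputs]
    # Final head at full weight, auxiliary heads scaled by alpha
    return per_head[-1] + alpha * sum(per_head[:-1])


def soft_dice_loss(pred, target):
    # Toy region loss for the demo: 1 - smoothed soft Dice over the whole tensor
    return 1 - (2 * (pred * target).sum() + 1) / (pred.sum() + target.sum() + 1)


# Toy usage with random tensors instead of real network outputs
torch.manual_seed(0)
outputs = [torch.softmax(torch.randn(2, 2, 8, 8, 8), dim=1) for _ in range(4)]
labels = torch.randint(0, 2, (2, 8, 8, 8))
target = F.one_hot(labels, 2).permute(0, 4, 1, 2, 3).float()
print(deep_supervision_loss(outputs, target, soft_dice_loss, alpha=0.33))
```

Setting `alpha=0` reduces this to the loss of the final head alone, which can be a useful check when the total loss behaves unexpectedly.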

loss:
```python
import torch
import torch.nn as nn


class TverskyLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # pred: softmax probabilities, target: one-hot labels, both (B, C, D, H, W)
        smooth = 1

        dice = 0.

        for i in range(pred.size(1)):
            # TP / (TP + 0.3*FP + 0.7*FN + smooth), each term summed over the three spatial dims
            dice += (pred[:, i] * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1) / (
                        (pred[:, i] * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1) +
                        0.3 * (pred[:, i] * (1 - target[:, i])).sum(dim=1).sum(dim=1).sum(dim=1) + 0.7 * (
                                    (1 - pred[:, i]) * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)

        dice = dice / pred.size(1)
        return torch.clamp((1 - dice).mean(), 0, 2)
```

model:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResUNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super().__init__()

        # Overrides nn.Module.training; model.train() / model.eval() still update it afterwards.
        self.training = training
        self.dorp_rate = 0.2

        self.conv = nn.Conv3d(in_channel, 16, 1, 1)

        self.encoder_stage1 = nn.Sequential(
            nn.Conv3d(in_channel, 16, 3, 1, padding=1),
            nn.PReLU(16),

            nn.Conv3d(16, 16, 3, 1, padding=1),
            nn.PReLU(16),
        )

        self.encoder_stage2 = nn.Sequential(
            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.encoder_stage3 = nn.Sequential(
            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=2, dilation=2),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=4, dilation=4),
            nn.PReLU(64),
        )

        self.encoder_stage4 = nn.Sequential(
            nn.Conv3d(128, 128, 3, 1, padding=3, dilation=3),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=4, dilation=4),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=5, dilation=5),
            nn.PReLU(128),
        )

        self.decoder_stage1 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),
        )

        self.decoder_stage2 = nn.Sequential(
            nn.Conv3d(128 + 64, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),
        )

        self.decoder_stage3 = nn.Sequential(
            nn.Conv3d(64 + 32, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),
        )

        self.decoder_stage4 = nn.Sequential(
            nn.Conv3d(32 + 16, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.down_conv1 = nn.Sequential(
            nn.Conv3d(16, 32, 2, 2),
            nn.PReLU(32)
        )

        self.down_conv2 = nn.Sequential(
            nn.Conv3d(32, 64, 2, 2),
            nn.PReLU(64)
        )

        self.down_conv3 = nn.Sequential(
            nn.Conv3d(64, 128, 2, 2),
            nn.PReLU(128)
        )

        self.down_conv4 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256)
        )

        self.up_conv2 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, 2, 2),
            nn.PReLU(128)
        )

        self.up_conv3 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, 2, 2),
            nn.PReLU(64)
        )

        self.up_conv4 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, 2, 2),
            nn.PReLU(32)
        )

        # Deep-supervision heads: 1x1x1 conv to the class count, upsample to input size, softmax
        self.map4 = nn.Sequential(
            nn.Conv3d(32, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 1, 1), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 4, 4), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 8, 8), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

    def forward(self, inputs):
        # inputs: (B, in_channel, D, H, W); D, H and W must be divisible by 8,
        # otherwise the skip connections below cannot be concatenated.
        long_range1 = self.encoder_stage1(inputs) + self.conv(inputs)

        short_range1 = self.down_conv1(long_range1)

        long_range2 = self.encoder_stage2(short_range1) + short_range1
        long_range2 = F.dropout(long_range2, self.dorp_rate, self.training)

        short_range2 = self.down_conv2(long_range2)

        long_range3 = self.encoder_stage3(short_range2) + short_range2
        long_range3 = F.dropout(long_range3, self.dorp_rate, self.training)

        short_range3 = self.down_conv3(long_range3)

        long_range4 = self.encoder_stage4(short_range3) + short_range3
        long_range4 = F.dropout(long_range4, self.dorp_rate, self.training)

        short_range4 = self.down_conv4(long_range4)

        outputs = self.decoder_stage1(long_range4) + short_range4
        outputs = F.dropout(outputs, self.dorp_rate, self.training)

        output1 = self.map1(outputs)

        short_range6 = self.up_conv2(outputs)

        outputs = self.decoder_stage2(torch.cat([short_range6, long_range3], dim=1)) + short_range6
        outputs = F.dropout(outputs, self.dorp_rate, self.training)

        output2 = self.map2(outputs)

        short_range7 = self.up_conv3(outputs)

        outputs = self.decoder_stage3(torch.cat([short_range7, long_range2], dim=1)) + short_range7
        outputs = F.dropout(outputs, self.dorp_rate, self.training)

        output3 = self.map3(outputs)

        short_range8 = self.up_conv4(outputs)

        outputs = self.decoder_stage4(torch.cat([short_range8, long_range1], dim=1)) + short_range8

        output4 = self.map4(outputs)

        # Training: all four deep-supervision outputs; evaluation: only the full-resolution one
        if self.training is True:
            return output1, output2, output3, output4
        else:
            return output4
```
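Each of the four heads (`map1` to `map4`) should return a probability map with the same spatial size as the input, which is what lets `train()` compare all of them against the same one-hot target. A minimal sketch that checks this in isolation, assuming a cubic input whose side is divisible by 8 (three stride-2 down-samplings); `make_head` and `head_output_shapes` are hypothetical helpers that copy the head layer pattern and feed random features at the resolutions ResUNet would use:

```python
import torch
import torch.nn as nn


def make_head(in_channels, n_classes, scale):
    """Same layer pattern as map1..map4: 1x1x1 conv, trilinear upsampling, softmax over classes."""
    return nn.Sequential(
        nn.Conv3d(in_channels, n_classes, 1, 1),
        nn.Upsample(scale_factor=(scale, scale, scale), mode='trilinear', align_corners=False),
        nn.Softmax(dim=1),
    )


def head_output_shapes(side, n_classes=2):
    """Run each head on random features at the resolution ResUNet feeds it; return the output shapes."""
    if side % 8 != 0:
        raise ValueError("each spatial side must be divisible by 8 (three stride-2 down-samplings)")
    # (channels, down-sampling factor) of the decoder features fed to map1, map2, map3, map4
    levels = [(256, 8), (128, 4), (64, 2), (32, 1)]
    shapes = []
    with torch.no_grad():
        for channels, factor in levels:
            feats = torch.randn(1, channels, side // factor, side // factor, side // factor)
            shapes.append(tuple(make_head(channels, n_classes, factor)(feats).shape))
    return shapes


print(head_output_shapes(32))  # four identical shapes: (1, 2, 32, 32, 32)
```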

I see a few issues with your TverskyLoss function.

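Notably, you're only adding smooth to the denominator and not to the numerator, which is not a standard implementation. For a class that is absent from a patch, the true positives are always 0, so the index is stuck at 0/smooth = 0 and that class contributes a constant loss of 1 that the network cannot reduce.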
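A small worked example of the empty-class case, using the same 0.3/0.7 weights. The helper `tversky_index` and its `smooth_numerator` switch are hypothetical, written only to compare the two smoothing choices side by side:

```python
import torch


def tversky_index(pred, target, smooth_numerator, smooth=1.0, alpha=0.3, beta=0.7):
    # Soft counts over the whole tensor
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    num = tp + (smooth if smooth_numerator else 0.0)
    return num / (tp + alpha * fp + beta * fn + smooth)


# A class that is absent from the ground truth and correctly predicted as absent
pred = torch.zeros(2, 8, 8, 8)
target = torch.zeros(2, 8, 8, 8)
print(tversky_index(pred, target, smooth_numerator=False))  # 0 -> loss 1, no matter what pred is
print(tversky_index(pred, target, smooth_numerator=True))   # 1 -> loss 0 for the perfect prediction
```

With smoothing in both places, the empty-class index becomes 1 / (0.3 * FP + 1), so false positives are still penalized and the gradient pushes them toward zero.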

Additionally, you can vectorize that function over the class dimension, which simplifies it to the following:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TverskyLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        smooth = 1
        dims = (2, 3, 4)  # spatial dims of a (B, C, D, H, W) tensor
        dice = ((pred * target).sum(dims) + smooth) / (
            (pred * target).sum(dims)
            + 0.3 * (pred * (1 - target)).sum(dims)
            + 0.7 * ((1 - pred) * target).sum(dims)
            + smooth)
        return torch.clamp((1 - dice.mean()), 0, 2)


# Old loop-based function, with smooth now also added to the numerator
class TverskyLoss_v0(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        smooth = 1

        dice = 0.

        for i in range(pred.size(1)):
            dice += ((pred[:, i] * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) / (
                (pred[:, i] * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1)
                + 0.3 * (pred[:, i] * (1 - target[:, i])).sum(dim=1).sum(dim=1).sum(dim=1)
                + 0.7 * ((1 - pred[:, i]) * target[:, i]).sum(dim=1).sum(dim=1).sum(dim=1)
                + smooth)

        dice = dice / pred.size(1)
        return torch.clamp((1 - dice).mean(), 0, 2)


# About 1.6 GB per tensor in float32
data1 = torch.rand((4, 3, 320, 320, 320), device=device)
data2 = torch.rand((4, 3, 320, 320, 320), device=device)

tloss1 = TverskyLoss()
tloss2 = TverskyLoss_v0()

# test that both give the same values
print(tloss1(data1, data2))
print(tloss2(data1, data2))

# test the speed of both functions over 100 iterations
if __name__ == '__main__':
    import timeit
    print(timeit.timeit("tloss1(data1, data2)", number=100, setup="from __main__ import data1, data2, tloss1, tloss2"))
    print(timeit.timeit("tloss2(data1, data2)", number=100, setup="from __main__ import data1, data2, tloss1, tloss2"))
```

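Once smooth is also added to the numerator of your original loop-based version, the test above shows that both functions return the same value. I've also added a timeit test: I got around 0.028 s and 6.1 s respectively on an RTX 3090, a speedup of roughly 218x for the vectorized code. Note that CUDA kernels run asynchronously, so timings taken without `torch.cuda.synchronize()` may partly reflect kernel launch overhead, and the real speedup may be smaller.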
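For a fairer GPU comparison, the clock can be read only after the queued kernels have finished. A minimal sketch, assuming the vectorized loss above; `time_fn` is a hypothetical helper, and the tensor sizes and `n_iter` are arbitrary choices that also run on a CPU:

```python
import time
import torch


def tversky_loss(pred, target, smooth=1.0):
    # Vectorized Tversky loss: sums over the spatial dims of a (B, C, D, H, W) tensor
    dims = (2, 3, 4)
    tp = (pred * target).sum(dims)
    fp = (pred * (1 - target)).sum(dims)
    fn = ((1 - pred) * target).sum(dims)
    index = (tp + smooth) / (tp + 0.3 * fp + 0.7 * fn + smooth)
    return torch.clamp(1 - index.mean(), 0, 2)


def time_fn(fn, *args, n_iter=10):
    """Average wall-clock seconds per call, waiting for queued CUDA work before reading the clock."""
    use_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    fn(*args)  # warm-up call
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # without this, the timer mostly measures kernel launches
    return (time.perf_counter() - start) / n_iter


device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.rand(2, 3, 64, 64, 64, device=device)
b = torch.rand(2, 3, 64, 64, 64, device=device)
print(f"{time_fn(tversky_loss, a, b):.6f} s per call")
```

Passing the loop-based version to `time_fn` with the same tensors gives a like-for-like comparison.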

Got it, thank you for the quick reply. By the way, do you mean the TverskyLoss code is the only problem, or is there anything else?

I did not notice any other issues, but I also don't have access to your metrics.LossAverage() and the other helpers.

I'm guessing this is from the lee-zq GitHub repo. You may want to contact the author there, or through the project page linked from the repo, and ask what to expect from his code regarding the metrics.

Thank you for your advice. Here is the Dice metric code; could you take a look at it too?

```python
import numpy as np
import torch


class DiceAverage(object):
    """Computes and stores the current and running-average per-class Dice score."""

    def __init__(self, class_num):
        self.class_num = class_num
        self.reset()

    def reset(self):
        self.value = np.asarray([0] * self.class_num, dtype='float64')
        self.avg = np.asarray([0] * self.class_num, dtype='float64')
        self.sum = np.asarray([0] * self.class_num, dtype='float64')
        self.count = 0

    def update(self, logits, targets):
        # One update per batch; avg is the mean of the per-batch Dice values
        self.value = DiceAverage.get_dices(logits, targets)
        self.sum += self.value
        self.count += 1
        self.avg = np.around(self.sum / self.count, 4)

    @staticmethod
    def get_dices(logits, targets):
        # Despite the name, `logits` receives softmax probabilities (output[3] in train()),
        # so this is a soft Dice per class, computed over the whole batch and smoothed by 1.
        dices = []
        for class_index in range(targets.size()[1]):
            inter = torch.sum(logits[:, class_index, :, :, :] * targets[:, class_index, :, :, :])
            union = torch.sum(logits[:, class_index, :, :, :]) + torch.sum(targets[:, class_index, :, :, :])
            dice = (2. * inter + 1) / (union + 1)
            dices.append(dice.item())
        return np.asarray(dices)
```
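To sanity-check what `get_dices` reports, its formula can be applied to a tiny hand-built volume. A minimal sketch; `soft_dices` is a hypothetical standalone copy of the same computation, and the 4x4x4 two-class volume is made up for illustration:

```python
import numpy as np
import torch


def soft_dices(probs, targets):
    """Same formula as DiceAverage.get_dices: per-class soft Dice over the whole batch, smoothed by 1."""
    dices = []
    for c in range(targets.size(1)):
        inter = torch.sum(probs[:, c] * targets[:, c])
        union = torch.sum(probs[:, c]) + torch.sum(targets[:, c])
        dices.append(((2.0 * inter + 1) / (union + 1)).item())
    return np.asarray(dices)


# Toy 2-class volume: class 1 occupies 8 of 64 voxels
labels = torch.zeros(1, 4, 4, 4, dtype=torch.long)
labels[0, :2, :2, :2] = 1
targets = torch.nn.functional.one_hot(labels, 2).permute(0, 4, 1, 2, 3).float()

print(soft_dices(targets, targets))  # perfect prediction: [1. 1.]

# Predicting background everywhere: class 1 gets (0 + 1) / (8 + 1), about 0.111
all_background = torch.zeros_like(targets)
all_background[:, 0] = 1
print(soft_dices(all_background, targets))
```

Because the smoothing term is added to both numerator and denominator, a class that is missing from both prediction and target scores 1 here, unlike the original Tversky loss.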

Hi Johnson,
Is there any problem with the DiceAverage code I attached?
In addition, I found that the 3D ResU-Net does not converge well on a slightly more complex image dataset. The initial learning rate was 1e-3, halved every ten epochs; could that decay be too fast? In this case, after 200 epochs the loss seems unchanged and the Dice score behaves much the same way (it even increases marginally). Surprisingly, the log looks like this:
|Epoch|Train_Loss|Train_dice|Val_Loss|Val_dice|
|---|---|---|---|---|
|104|0.1226|0.9366|0.1608|0.6973|
|105|0.1291|0.9319|0.1551|0.6932|
|106|0.1493|0.9161|0.2468|0.5597|
|107|4.3836|0.1521|0.5095|0|
|108|8.5307|0.0033|0.5094|0|
|109|8.7633|0.0008|0.5094|0|
|110|8.7956|0.0005|0.5094|0|

What caused this? Should I change the initial learning rate, or switch to a different optimization method such as ADMM? Any suggestions would be appreciated. Thanks.
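For reference, the schedule described above (start at 1e-3, halve every ten epochs) can be reproduced with PyTorch's `StepLR` to see the learning rate at any epoch. A minimal sketch, assuming the decay starts at epoch 0 and is applied once per epoch; the one-parameter model and the choice of Adam are placeholders, and the original repo may use its own learning-rate helper instead of `StepLR`:

```python
import torch

# Placeholder one-parameter model, only used to drive the scheduler
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
# Halve the learning rate every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

lrs = []
for epoch in range(120):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()   # a real loop would train on all batches here
    scheduler.step()   # called once per epoch

for epoch in (0, 10, 50, 100, 107):
    print(epoch, lrs[epoch])
```

Under that assumption, the learning rate is already 1e-3 / 2^10, about 1e-6, from epoch 100 onward, which is worth keeping in mind when judging whether it is still too high around epoch 107.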

That looks like the learning rate is a bit too high for the most recent updates. Have you tried a lower learning rate? You may also need a constant multiplier in your train function to better balance the Tversky loss and the BCE loss, or try adjusting the smooth parameter in your Tversky loss. I've also seen the 0.3 and 0.7 in your Tversky loss treated as a hyperparameter in other implementations, where they are replaced with alpha and (1 - alpha) respectively.
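Those suggestions can be combined into one configurable loss. A minimal sketch, assuming the same (B, C, D, H, W) softmax probabilities and one-hot targets as above: `alpha` and `smooth` become constructor arguments, and `combined_loss` with its `bce_weight` is a hypothetical helper for the constant multiplier. This `alpha` is the Tversky false-positive weight, unrelated to the deep-supervision `alpha` argument of `train()`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TverskyLoss(nn.Module):
    """Tversky loss with tunable alpha (false-positive weight) and smooth.

    The false-negative weight is 1 - alpha, so alpha=0.3 reproduces the 0.3/0.7 split above.
    """

    def __init__(self, alpha=0.3, smooth=1.0):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth

    def forward(self, pred, target):
        dims = tuple(range(2, pred.dim()))  # all spatial dims of (B, C, ...)
        tp = (pred * target).sum(dims)
        fp = (pred * (1 - target)).sum(dims)
        fn = ((1 - pred) * target).sum(dims)
        index = (tp + self.smooth) / (tp + self.alpha * fp + (1 - self.alpha) * fn + self.smooth)
        return 1 - index.mean()  # already in [0, 1] for probabilities, so no clamp needed


def combined_loss(pred, target, tversky, bce_weight=0.5):
    """Tversky term plus a constant-weighted BCE term (bce_weight is a hypothetical hyperparameter)."""
    return tversky(pred, target) + bce_weight * F.binary_cross_entropy(pred, target)


# Toy usage with random probabilities
torch.manual_seed(0)
labels = torch.randint(0, 2, (2, 4, 4, 4))
onehot = F.one_hot(labels, 2).permute(0, 4, 1, 2, 3).float()
probs = torch.softmax(torch.randn(2, 2, 4, 4, 4), dim=1)
print(combined_loss(probs, onehot, TverskyLoss(alpha=0.3, smooth=1.0), bce_weight=0.5))
```

Setting `alpha=0.5` turns the Tversky index into a smoothed soft Dice, which is a convenient baseline when tuning the weights.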

As for 3D UNets used in segmentation, that’s not something I’ve worked on before. So, aside from pointing out obvious issues, I’m afraid I wouldn’t be of much help without downloading the dataset and working on it myself.

Your best bet would be to contact the repo creator: open an issue on his GitHub repo describing the problem you're having, or try to reach him in the comments on his project website.

Alternatively, you could look at other 3D U-Net segmentation implementations, such as wolny/pytorch-3dunet on GitHub (a 3D U-Net model for volumetric semantic segmentation written in PyTorch).

Thank you for your kind words and your continued help. I'll try again.