Here is the ResU-Net code I used. Could you help me check it? I would be grateful for your answer:
train:
def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    """Run one training epoch with deep supervision and return a log dict.

    The model is expected to return four predictions per batch (indexable as
    ``output[0]`` .. ``output[3]``). Each prediction is scored with
    ``loss_func`` plus a binary cross-entropy term; the last prediction is the
    main loss and the first three are auxiliary losses weighted by ``alpha``.

    Args:
        model: Network to train; called as ``model(data)``.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_func: Callable ``loss_func(prediction, target)`` returning a
            scalar tensor.
        n_labels: Number of classes, passed to the one-hot encoder and to the
            Dice meter. When it equals 3 an extra Dice entry is logged.
        alpha: Weight applied to the sum of the three auxiliary losses.

    Returns:
        OrderedDict with ``'Train_Loss'``, ``'Train_dice_liver'`` and, when
        ``n_labels == 3``, ``'Train_dice_tumor'``.

    Note:
        ``epoch`` and ``device`` are not parameters; they are read from the
        enclosing module scope and must be defined before this is called.
    """

    def head_loss(prediction, one_hot):
        # Same two-term loss for every supervision head.
        # NOTE(review): binary_cross_entropy needs a floating-point target in
        # [0, 1]; this assumes common.to_one_hot_3d returns one -- confirm.
        return loss_func(prediction, one_hot) + torch.nn.functional.binary_cross_entropy(
            prediction, one_hot
        )

    print("=======Epoch:{}=======lr:{}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for data, target in train_loader:
        # NOTE(review): emptying the CUDA cache every iteration is kept from
        # the original; it usually slows training and is rarely required.
        torch.cuda.empty_cache()
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target, n_labels)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        # Heads 0-2 are auxiliary outputs, head 3 is the final prediction.
        head_losses = [head_loss(output[i], target) for i in range(4)]
        auxiliary = head_losses[0] + head_losses[1] + head_losses[2]
        loss = head_losses[3] + alpha * auxiliary

        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), data.size(0))
        train_dice.update(output[3], target)

    val_log = OrderedDict({'Train_Loss': train_loss.avg, 'Train_dice_liver': train_dice.avg[1]})
    if n_labels == 3:
        val_log.update({'Train_dice_tumor': train_dice.avg[2]})
    return val_log
loss:
class TverskyLoss(nn.Module):
    """Multi-class Tversky loss averaged over channels and batch.

    For every sample and channel the Tversky index is

        (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    where TP, FP and FN are soft counts summed over all spatial positions.
    The loss is ``1 - index`` averaged over channels, then over the batch.

    The smoothing term appears in both numerator and denominator so that a
    class which is absent from the target and correctly predicted as absent
    scores 1 (not 0), and a perfect prediction yields a loss of exactly 0.

    Args:
        alpha: Weight of false positives (default 0.3).
        beta: Weight of false negatives (default 0.7).
        smooth: Additive smoothing constant (default 1).
    """

    def __init__(self, alpha=0.3, beta=0.7, smooth=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar Tversky loss.

        Args:
            pred: Tensor of shape (batch, channels, *spatial) holding
                per-class scores (probabilities are expected).
            target: Tensor of the same shape holding the one-hot reference.

        Returns:
            Zero-dimensional tensor clamped to the range [0, 2].
        """
        # Allow integer/bool one-hot targets without changing the result.
        target = target.to(pred.dtype)

        # Collapse every spatial axis so the same code handles 2D and 3D.
        pred_flat = pred.flatten(2)
        target_flat = target.flatten(2)

        true_pos = (pred_flat * target_flat).sum(dim=2)
        false_pos = (pred_flat * (1 - target_flat)).sum(dim=2)
        false_neg = ((1 - pred_flat) * target_flat).sum(dim=2)

        index = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        )

        # Mean over channels gives one score per sample; then mean over batch.
        per_sample = index.mean(dim=1)
        return torch.clamp((1 - per_sample).mean(), 0, 2)
model:
class ResUNet(nn.Module):
    """3D residual U-Net with dilated encoder stages and deep supervision.

    The encoder halves the spatial resolution three times (stride-2
    convolutions) and the decoder doubles it three times, concatenating the
    matching encoder features at each level. Every spatial dimension of the
    input should therefore be divisible by 8, otherwise the skip
    concatenations do not line up.

    Each decoder level has a 1x1x1 classification head followed by trilinear
    upsampling to the input resolution and a channel-wise softmax.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output classes.
        training: Initial value of the module's ``training`` flag. Calling
            ``train()`` / ``eval()`` afterwards overrides it as usual.

    Forward returns:
        In training mode a tuple ``(output1, output2, output3, output4)``
        ordered from the coarsest to the finest decoder level; otherwise only
        ``output4``. All outputs have shape
        ``(batch, out_channel, *input_spatial)``.
    """

    def __init__(self, in_channel=1, out_channel=2, training=True):
        super().__init__()
        self.training = training
        # Attribute name (including its spelling) is kept for compatibility.
        self.dorp_rate = 0.2

        # Submodules are created in the original order so parameter order,
        # state_dict keys and seeded initialisation stay the same.

        # 1x1x1 projection so the input can be added to the stage-1 features.
        self.conv = nn.Conv3d(in_channel, 16, 1, 1)

        self.encoder_stage1 = self._conv_stack(in_channel, 16, (1, 1))
        self.encoder_stage2 = self._conv_stack(32, 32, (1, 1, 1))
        self.encoder_stage3 = self._conv_stack(64, 64, (1, 2, 4))
        self.encoder_stage4 = self._conv_stack(128, 128, (3, 4, 5))

        self.decoder_stage1 = self._conv_stack(128, 256, (1, 1, 1))
        self.decoder_stage2 = self._conv_stack(128 + 64, 128, (1, 1, 1))
        self.decoder_stage3 = self._conv_stack(64 + 32, 64, (1, 1, 1))
        self.decoder_stage4 = self._conv_stack(32 + 16, 32, (1, 1))

        self.down_conv1 = self._resample(nn.Conv3d, 16, 32)
        self.down_conv2 = self._resample(nn.Conv3d, 32, 64)
        self.down_conv3 = self._resample(nn.Conv3d, 64, 128)
        # Keeps the resolution; only widens 128 -> 256 channels.
        self.down_conv4 = self._conv_stack(128, 256, (1,))

        self.up_conv2 = self._resample(nn.ConvTranspose3d, 256, 128)
        self.up_conv3 = self._resample(nn.ConvTranspose3d, 128, 64)
        self.up_conv4 = self._resample(nn.ConvTranspose3d, 64, 32)

        self.map4 = self._head(32, out_channel, 1)
        self.map3 = self._head(64, out_channel, 2)
        self.map2 = self._head(128, out_channel, 4)
        self.map1 = self._head(256, out_channel, 8)

    @staticmethod
    def _conv_stack(in_ch, out_ch, dilations):
        """Build a chain of 3x3x3 Conv3d + PReLU, one pair per dilation."""
        layers = []
        for rate in dilations:
            # padding == dilation keeps the spatial size for a 3x3x3 kernel.
            layers.append(nn.Conv3d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate))
            layers.append(nn.PReLU(out_ch))
            in_ch = out_ch
        return nn.Sequential(*layers)

    @staticmethod
    def _resample(conv_cls, in_ch, out_ch):
        """Build a kernel-2 / stride-2 (transposed) convolution + PReLU."""
        return nn.Sequential(conv_cls(in_ch, out_ch, 2, 2), nn.PReLU(out_ch))

    @staticmethod
    def _head(in_ch, out_ch, scale):
        """Build a 1x1x1 classifier, upsampling by ``scale``, then softmax."""
        if scale == 1:
            # Already at full resolution: skip the interpolation pass.
            # Identity keeps the Sequential indices unchanged.
            resize = nn.Identity()
        else:
            resize = nn.Upsample(
                scale_factor=(scale, scale, scale), mode='trilinear', align_corners=False
            )
        return nn.Sequential(nn.Conv3d(in_ch, out_ch, 1, 1), resize, nn.Softmax(dim=1))

    def _drop(self, features):
        """Apply dropout with the configured rate while in training mode."""
        return F.dropout(features, self.dorp_rate, self.training)

    def forward(self, inputs):
        # ----- encoder: residual stage, then downsample -----
        long_range1 = self.encoder_stage1(inputs) + self.conv(inputs)
        short_range1 = self.down_conv1(long_range1)

        long_range2 = self._drop(self.encoder_stage2(short_range1) + short_range1)
        short_range2 = self.down_conv2(long_range2)

        long_range3 = self._drop(self.encoder_stage3(short_range2) + short_range2)
        short_range3 = self.down_conv3(long_range3)

        long_range4 = self._drop(self.encoder_stage4(short_range3) + short_range3)
        short_range4 = self.down_conv4(long_range4)

        # ----- decoder: residual stage, side output, then upsample -----
        outputs = self._drop(self.decoder_stage1(long_range4) + short_range4)
        output1 = self.map1(outputs)

        short_range6 = self.up_conv2(outputs)
        merged = torch.cat([short_range6, long_range3], dim=1)
        outputs = self._drop(self.decoder_stage2(merged) + short_range6)
        output2 = self.map2(outputs)

        short_range7 = self.up_conv3(outputs)
        merged = torch.cat([short_range7, long_range2], dim=1)
        outputs = self._drop(self.decoder_stage3(merged) + short_range7)
        output3 = self.map3(outputs)

        short_range8 = self.up_conv4(outputs)
        merged = torch.cat([short_range8, long_range1], dim=1)
        # No dropout on the last stage.
        outputs = self.decoder_stage4(merged) + short_range8
        output4 = self.map4(outputs)

        # Side outputs are only needed for deep supervision during training.
        if self.training:
            return output1, output2, output3, output4
        return output4