Here is my code.
Loss function:
def dice_loss(self, true, logits, eps=1e-7):
    """Compute the Sørensen–Dice loss, ``1 - mean(Dice coefficient)``.

    Optimizers minimize, and a larger Dice coefficient means better
    overlap, so the value returned is ``1 - Dice`` (0 for a perfect
    prediction, approaching 1 for no overlap).

    Works for any number of spatial dimensions (2-D images, 3-D volumes).

    Args:
        true: integer class labels shaped ``[B, *spatial]`` or
            ``[B, 1, *spatial]``. Values are cast to ``long`` before use.
        logits: raw model output shaped ``[B, C, *spatial]``. With
            ``C == 1`` the problem is treated as binary (sigmoid);
            otherwise a softmax over the channel axis is used.
        eps: added to the denominator for numerical stability.

    Returns:
        Scalar tensor: one minus the Dice coefficient averaged over
        classes. Half-precision logits are promoted to float32, so the
        result is float32 in that case.

    Raises:
        ValueError: if the shape of ``true`` is incompatible with ``logits``.
    """
    num_classes = logits.shape[1]

    # Labels must live on the same device as the one-hot lookup table.
    labels = true.long().to(logits.device)
    if labels.dim() == logits.dim():
        # An explicit channel axis is only meaningful when it is a singleton.
        if labels.shape[1] != 1:
            raise ValueError(
                "true has a channel axis of size {}; expected 1".format(
                    labels.shape[1]
                )
            )
        labels = labels.squeeze(1)
    if labels.dim() != logits.dim() - 1:
        raise ValueError(
            "true must be shaped [B, *spatial] or [B, 1, *spatial] to match "
            "logits shaped [B, C, *spatial]"
        )

    # Summing probabilities over a whole image/volume easily exceeds the
    # fp16 range (~65504), so reduced-precision inputs are promoted first.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()

    if num_classes == 1:
        foreground = torch.sigmoid(logits)
        probas = torch.cat([foreground, 1 - foreground], dim=1)
        num_channels = 2
    else:
        probas = F.softmax(logits, dim=1)
        num_channels = num_classes

    # One-hot encode by indexing an identity matrix: [B, *spatial, K].
    identity = torch.eye(
        num_channels, dtype=probas.dtype, device=probas.device
    )
    one_hot = identity[labels]
    # Move the class axis next to the batch axis: [B, K, *spatial].
    channels_first = (0, one_hot.dim() - 1) + tuple(
        range(1, one_hot.dim() - 1)
    )
    one_hot = one_hot.permute(*channels_first)
    if num_classes == 1:
        # probas is ordered (foreground, background); match that order.
        one_hot = torch.cat([one_hot[:, 1:2], one_hot[:, 0:1]], dim=1)

    # Reduce over batch and every spatial axis, keeping one value per class.
    dims = (0,) + tuple(range(2, probas.dim()))
    intersection = torch.sum(probas * one_hot, dims)
    cardinality = torch.sum(probas + one_hot, dims)
    dice = (2.0 * intersection / (cardinality + eps)).mean()
    return 1 - dice
Model:
class UNet2D(nn.Module):
    """U-Net style encoder/decoder built from 3-D layers.

    Despite the name, every layer is volumetric (``Conv3d``,
    ``MaxPool3d``, ``ConvTranspose3d``, ``BatchNorm3d``), so inputs are
    5-D tensors ``[B, inputChannels, H, W, D]``.

    The pooling layers use ``kernel_size=3, stride=2`` without padding and
    the transposed convolutions use per-level kernel sizes, so the skip
    connections only line up for input sizes that survive that arithmetic
    (for example ``H = W = 32, D = 43``).

    Args:
        inputChannels: number of channels of the input volume.
        outputChannels: number of output channels of the final 1x1x1 conv.
        init_features: channel width of the first encoder level; each
            deeper level doubles it.
    """

    def __init__(self, inputChannels, outputChannels, init_features=32):
        super().__init__()
        features = init_features

        # Contracting path.
        self.encoder1 = UNet2D._block(inputChannels, features, name="enc1")
        self.pool1 = nn.MaxPool3d(kernel_size=3, stride=2)
        self.encoder2 = UNet2D._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=2)
        self.encoder3 = UNet2D._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool3d(kernel_size=3, stride=2)
        self.encoder4 = UNet2D._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool3d(kernel_size=3, stride=2)

        self.bottleneck = UNet2D._block(
            features * 8, features * 16, name="bottleneck"
        )

        # Expanding path. Decoder blocks 4..2 take twice the channels
        # because the matching encoder output is concatenated in forward().
        self.upconv4 = nn.ConvTranspose3d(
            features * 16, features * 8, kernel_size=(3, 3, 4), stride=2
        )
        self.decoder4 = UNet2D._block(
            (features * 8) * 2, features * 8, name="dec4"
        )
        self.upconv3 = nn.ConvTranspose3d(
            features * 8, features * 4, kernel_size=(3, 3, 4), stride=2
        )
        self.decoder3 = UNet2D._block(
            (features * 4) * 2, features * 4, name="dec3"
        )
        self.upconv2 = nn.ConvTranspose3d(
            features * 4, features * 2, kernel_size=(3, 3, 3), stride=2
        )
        self.decoder2 = UNet2D._block(
            (features * 2) * 2, features * 2, name="dec2"
        )
        self.upconv1 = nn.ConvTranspose3d(
            features * 2, features, kernel_size=(4, 4, 3), stride=2
        )
        # No skip connection at the top level, hence `features` inputs.
        self.decoder1 = UNet2D._block(features, features, name="dec1")
        self.conv = nn.Conv3d(
            in_channels=features, out_channels=outputChannels, kernel_size=1
        )

    def forward(self, x):
        """Run the encoder/decoder up to (and including) ``upconv1``.

        Args:
            x: input tensor ``[B, inputChannels, H, W, D]``.

        Returns:
            A tuple ``(self.conv, self.decoder1, features)`` where
            ``features`` is the output of ``upconv1``. The two modules are
            returned *un-applied*: the caller decides whether to run
            ``decoder1`` and/or the final ``conv`` on ``features``.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottom = self.bottleneck(self.pool4(enc4))

        # Skip tensors already live on the same device as the decoder
        # activations, so no explicit device transfer is needed; this keeps
        # the model usable on CPU or on any GPU index.
        dec4 = torch.cat((self.upconv4(bottom), enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = torch.cat((self.upconv3(dec4), enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = torch.cat((self.upconv2(dec3), enc2), dim=1)
        dec1 = self.upconv1(self.decoder2(dec2))
        return self.conv, self.decoder1, dec1

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (Conv3d -> BatchNorm3d -> LeakyReLU) x 2 block.

        Sub-module names are prefixed with ``name`` so that state-dict
        keys are unique per level (e.g. ``enc1conv1``).
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv3d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm3d(num_features=features)),
                    (name + "relu1", nn.LeakyReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv3d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm3d(num_features=features)),
                    (name + "relu2", nn.LeakyReLU(inplace=True)),
                ]
            )
        )
And the main training script:
# Training script: fits the U-Net with the Dice loss and writes one
# checkpoint per epoch.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
print(device)

# The network class shown alongside this script is named UNet2D (it is built
# from 3-D layers); `UNet3D` is not defined anywhere in the visible code.
unet = UNet2D(config.INPUT_CHANNELS, config.OUTPUT_CHANNELS)
unet.to(device)
dsc_loss = losses2()

# Mixed precision is handled by autocast + GradScaler instead of flipping the
# whole model between half() and float() around every optimizer step: that
# rounds the weights to fp16 each iteration (small Adam updates are lost) and
# gives no protection against fp16 gradient underflow. The master weights
# stay in float32; on CPU both helpers are disabled and training is plain
# float32.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_validation_dsc = 0.0
optimizer = optim.Adam(unet.parameters(), lr=config.LR)
loss_train = []
loss_valid = []
step = 0
m = nn.ReLU()  # stateless, so one instance is enough for the whole run

for epoch in range(config.EPOCHS):
    unet.train()
    for idc, (x, y_true) in enumerate(loader_train):
        optimizer.zero_grad()
        step += 1
        x, y_true = x.to(device), y_true.to(device)
        x = x.float()

        with torch.cuda.amp.autocast(enabled=use_amp):
            # forward() returns the final conv and decoder1 un-applied.
            # NOTE(review): `do` (decoder1) is never applied, and a ReLU is
            # applied to the logits before the loss's softmax/sigmoid --
            # confirm both are intentional.
            dr, do, y_pred = unet(x)
            y_pred = m(dr(y_pred))
            loss = dsc_loss.dice_loss(y_true, y_pred)

        # Debug views of one slice (hard-coded slice 75, channels 6 and 0).
        # NOTE(review): these draw on the current axes every iteration and
        # nothing clears or shows the figure here.
        plt.imshow(y_true[0, :, :, 75].detach().cpu().type(torch.float32))
        plt.imshow(y_pred[0, 6, :, :, 75].detach().cpu().type(torch.float32))
        plt.imshow(
            y_pred[0, 0, :, :, 75].detach().cpu().type(torch.float32),
            cmap='gray',
        )

        print("{} epoch {} iteration and loss is {}".format(epoch+1,idc+1,loss.item()))
        loss_train.append(loss.item())

        # Scale the loss so fp16 gradients do not underflow; step() unscales
        # them before the optimizer update and skips the step on inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    print(unet.decoder3.dec3conv1.weight.grad.mean())
    torch.save({
        'epoch': epoch,
        'model_state_dict': unet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Detached CPU copy: avoids pickling the autograd graph and keeps the
        # checkpoint loadable on machines without a GPU.
        'loss': loss.detach().cpu(),
    }, 'model{}.pt'.format(epoch))