Zero gradients at Training

Hi all,

I tried to implement U-Net 3D on MRI Dataset with a custom shape and ı modified the strides and padding to match the sizes at skip connection processes. But when I train the network all layers have zero-grads but last. As a result of that, my loss stuck at a level that random parameters can provide. I tried to change the activation functions, shapes, and many maaaany other things. I am exhausted. Can somebody help ?

UNET has [1,4,240,240,155] input and [1,7,240,240,155] output. I use generalised dice loss with softmax. Model depth is 4. Because of the size of the process I use half and fp16 on model and input,output.

Here are my codes

For loss


    def dice_loss(self,true, logits, eps=1e-7):
        """Computes the Sørensen–Dice loss.
        Note that PyTorch optimizers minimize a loss. In this
        case, we would like to maximize the dice loss so we
        return the negated dice loss.
        Args:
            true: a tensor of shape [B, 1, H, W].
            logits: a tensor of shape [B, C, H, W]. Corresponds to
                the raw output or logits of the model.
            eps: added to the denominator for numerical stability.
        Returns:
            dice_loss: the Sørensen–Dice loss.
        """
        num_classes = logits.shape[1]
        if num_classes == 1:
            true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
            true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
            true_1_hot_f = true_1_hot[:, 0:1, :, :]
            true_1_hot_s = true_1_hot[:, 1:2, :, :]
            true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
            pos_prob = torch.sigmoid(logits)
            neg_prob = 1 - pos_prob
            probas = torch.cat([pos_prob, neg_prob], dim=1)
        else:
            true_1_hot = torch.eye(num_classes)[true.long()]
            true_1_hot = true_1_hot.permute(0, 4, 1, 2 , 3).float()
            probas = F.softmax(logits,dim=1)

        true_1_hot = true_1_hot.type(logits.type())
        dims = (0,) + tuple(range(2, true.ndimension()))
        intersection = torch.sum(probas * true_1_hot, dims)
        cardinality = torch.sum(probas + true_1_hot, dims)
        dice_loss = (2. * intersection / (cardinality + eps)).mean()
        return (1 - dice_loss)```


Model 

class UNet2D(nn.Module):

def __init__(self,inputChannels,outputChannels,init_features=32):
    super(UNet2D, self).__init__()

    features = init_features
    self.encoder1 = UNet2D._block(inputChannels, features, name="enc1")
    self.pool1 = nn.MaxPool3d(kernel_size=3, stride=2)
    self.encoder2 = UNet2D._block(features, features * 2, name="enc2")
    self.pool2 = nn.MaxPool3d(kernel_size=3, stride=2)
    self.encoder3 = UNet2D._block(features * 2, features * 4, name="enc3")
    self.pool3 = nn.MaxPool3d(kernel_size=3, stride=2)
    self.encoder4 = UNet2D._block(features * 4, features * 8, name="enc4")
    self.pool4 = nn.MaxPool3d(kernel_size=3, stride=2)

    self.bottleneck = UNet2D._block(features * 8, features * 16, name="bottleneck")

    self.upconv4 = nn.ConvTranspose3d(
        features * 16, features * 8, kernel_size=(3,3,4), stride=2
    )
    self.decoder4 = UNet2D._block((features * 8)*2 , features * 8, name="dec4")
    self.upconv3 = nn.ConvTranspose3d(
        features * 8, features * 4, kernel_size=(3,3,4), stride=2
    )
    self.decoder3 = UNet2D._block((features * 4)*2 , features * 4, name="dec3")
    self.upconv2 = nn.ConvTranspose3d(
        features * 4, features * 2, kernel_size=(3,3,3), stride=2
    )
    self.decoder2 = UNet2D._block((features * 2)*2 , features * 2, name="dec2")
    self.upconv1 = nn.ConvTranspose3d(
        features * 2, features, kernel_size=(4,4,3), stride=2
    )
    self.decoder1 = UNet2D._block(features, features, name="dec1")

    self.conv = nn.Conv3d(
        in_channels=features, out_channels=outputChannels, kernel_size=1
    )

def forward(self, x):
    device='cuda:0'
    enc1 = self.encoder1(x)
    enc2 = self.encoder2(self.pool1(enc1))

    enc3 = self.encoder3(self.pool2(enc2))
  

    enc4 = self.encoder4(self.pool3(enc3))
    enc5 = self.bottleneck(self.pool4(enc4))
    
    dec4 = self.upconv4(enc5)
    
    dec4 = torch.cat((dec4, enc4.to(device)), dim=1)
    dec4 = self.decoder4(dec4)
    dec3 = self.upconv3(dec4)
   
    dec3 = torch.cat((dec3, enc3.to(device)), dim=1)
    dec3 = self.decoder3(dec3)
    dec2 = self.upconv2(dec3)
    
    dec2 = torch.cat((dec2, enc2.to(device)), dim=1)
    
    dec1 = self.upconv1(self.decoder2(dec2))
   
    torch.cuda.empty_cache()
  
    
    return self.conv,self.decoder1,dec1
    
@staticmethod
def _block(in_channels, features, name):
    return nn.Sequential(
        OrderedDict(
            [
                (
                    name + "conv1",
                    nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm1", nn.BatchNorm3d(num_features=features)),
                (name + "relu1", nn.LeakyReLU(inplace=True)),
                (
                    name + "conv2",
                    nn.Conv3d(
                        in_channels=features,
                        out_channels=features,
                        kernel_size=3,
                        padding=1,
                        bias=False,
                    ),
                ),
                (name + "norm2", nn.BatchNorm3d(num_features=features)),
                (name + "relu2", nn.LeakyReLU(inplace=True)),
            ]
        )
    )```

And main code


device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
print(device)

unet = UNet3D(config.INPUT_CHANNELS,config.OUTPUT_CHANNELS)
unet.to(device)

dsc_loss = losses2()
unet.half()
best_validation_dsc = 0.0
optimizer = optim.Adam(unet.parameters(), lr=config.LR)

loss_train = []
loss_valid = []

step = 0

for epoch in range(config.EPOCHS):
    unet.train()

    
    for idc,loader in enumerate(loader_train):
        optimizer.zero_grad()

        step+=1
        x,y_true=loader
        x,y_true=x.to(device),y_true.to(device)
        x=x.type(torch.float16)
        y_true=y_true.type(torch.float16)
        unet.half()
        dr,do,y_pred=unet(x)
        
        m=nn.ReLU()
        y_pred=m(dr((y_pred)))
        plt.imshow(y_true[0,:,:,75].detach().cpu().type(torch.float32))
        plt.imshow(y_pred[0,6,:,:,75].detach().cpu().type(torch.float32))
        plt.imshow(y_pred[0,0,:,:,75].detach().cpu().type(torch.float32),cmap='gray')
        
        loss=dsc_loss.dice_loss(y_true,y_pred)
        print("{} epoch {} iteration and loss is {}".format(epoch+1,idc+1,loss.item()))
        loss_train.append(loss.item())
        
        
        
        loss.backward()
        #print(unet.encoder1.enc1conv1.weight.max())
        
        
        
        unet.float()
        optimizer.step()
        print(unet.decoder3.dec3conv1.weight.grad.mean())

        
    torch.save({
            'epoch': epoch,
            'model_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, 'model{}.pt'.format(epoch)) 

Yeah, that’s kind of dangerous, since your model and training might suffer from gradient underflow.
This might happen, if the gradients have a small magnitude and may thus not be representable in float16. You can use the built-in mixed-precision training utility via torch.cuda.amp, which uses a gradient scaling mechanism to prevent the underflow of the gradients as explained here.

You can find more information about it here.

I tried with scale and it shows gradient mean something else than zero. So I think it start to work but it only can perform one iteration then throws memory error. I deleted all the variables and call
torch.cuda.empty_cache() but still same. How can I check which variable takes how much memory on GPU.
Also another question about patchwise UNET. As I said my input is [1,4,240,240,155] if I use patches ,for example [1,4,64,64,64] how can I arrange it for some custom different input? Manually padding zeros ? Or is there anything for that ? Lİke object detectors does ?

Thanks
Abdullah