Loss doesn't decrease, requires_grad = True

I'm trying to do transfer learning: fine-tuning a few layers of one model onto another model with a shared architecture.

After initializing the new model and loading the shared weights, I start training as usual, but no training occurs.

I looked at a Conv3d layer object and these are its attributes (only copying the things I suspect are relevant):
training = True
weight.grad_fn = None
weight.requires_grad = True
weight.grad.grad = None
weight.grad.grad_fn = None
weight.grad.requires_grad = False

I can't find any official documentation about the grad object inside the weight object, why the two have separate requires_grad attributes, or what they mean.

Can anyone tell me why my network doesn't train?

Hi,

You want to make sure that the parameters you give to the optimizer are the right ones as well. In particular, you should create the optimizer after you have loaded the shared weights.

That's interesting, can you explain why that is crucial?
Why must I create the optimizer with Adam(model.parameters()) only after loading the pretrained weights?

Anyway, this is not the problem in my case, since I also don't see any training when I try to train the network without loading the weights.

I saw a post about a similar issue that was solved by changing a bunch of LeakyReLUs to ReLUs, and I suspect I'm facing a similar problem.

The original net uses LeakyReLU. My new net is composed of the first half of the original net, followed by a flatten layer, a linear layer, and a ReLU. The output of this ReLU is a single neuron, and the loss is MSE.

It always returns 0, which is why I think there is no learning.

But there is also a bias term, so why doesn't the net simply learn to add bias and produce a positive output?

Hi,

This would be a problem because the optimizer might be created with the "old" parameters, so it would try to update the old parameters and not the new ones that get the gradients.
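In code, that ordering looks roughly like this (a minimal sketch using the names from the snippets later in the thread; the checkpoint path is a placeholder):

# sketch: load the shared/pretrained weights first
model = half_Unet_age()
pretrained = torch.load('checkpoint.pth')['model_state_dict']  # placeholder path
model.load_partial_state_dict(pretrained)
# only now build the optimizer, so its param_groups reference the parameters that will actually receive gradients
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)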

If your parameters do have .grad fields that get populated during the backward pass, but they are all zeros, then it is indeed something in your network.
You want to make sure that no part of your net/loss generates an all-zero gradient. This can happen if you have special layers or bad initialization (for example, something that is always negative just before a ReLU).
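One quick way to check both points (a sketch, assuming the model, criterion, input and target variables from the training loop later in the thread):

# run one backward pass, then inspect every parameter's gradient
loss = criterion(model(input), target)
loss.backward()
for name, p in model.named_parameters():
    if p.grad is None:
        print(name, 'has no grad')                                # never received a gradient
    else:
        print(name, 'grad abs-sum =', p.grad.abs().sum().item())  # 0.0 means an all-zero gradient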

I tried another sanity check: I just flattened the input of the entire net (a 1×64×64×32 tensor) and fed it into a linear layer that ends in one neuron going into the MSE loss.

Still the same behavior of a static loss. It is important to note that I'm giving the same input every time.

So maybe this means it's not the net itself after all?
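For concreteness, a minimal sketch of such a probe (whether the trailing ReLU from decode_age was kept is an assumption):

import torch.nn as nn

# flatten the raw 1x64x64x32 volume and regress it to a single value
probe = nn.Sequential(
    nn.Flatten(),
    nn.Linear(64 * 64 * 32, 1),
    nn.ReLU(),  # assumed: same head as decode_age; an all-negative pre-activation would again give zero gradients
)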

Could you share a small piece of code I could run that reproduces that?

Thank you for trying to help!

This is the only thing I can send without going too deep into the details. Maybe you can find something wrong with this architecture as it is.

The input x is of shape (batch 2, channel 1, width 64, height 64, depth 32).
The target is of shape (batch 2, dim 1).

import torch.nn as nn
import torch
from torchsummary import summary
from lib.medzoo.BaseModelClass import BaseModel
import matplotlib.pyplot as plt



class half_Unet_age(BaseModel):

    """
    Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650
    with changes to fit volume inpainting
    """

    def __init__(self, in_channels=1, n_classes=1, dim=(64,64,32), base_n_filter=8):
        super(half_Unet_age, self).__init__()
        self.n_classes = n_classes # no classification, just regression
        self.in_channels = in_channels
        self.n_subjects = args.subject_num
        self.dim = dim
        self.base_n_filter = base_n_filter
        self.lrelu = nn.LeakyReLU()
        self.dropout3d = nn.Dropout3d(p=0.6)
        # if learning a personal embedding, add a channel to the first conv layer.
        # NOTE: the condition below is an assumption, reconstructed from the --subject_embedding flag in the argparse setup further down (the original if-line was garbled).
        if args.subject_embedding:
            self.conv3d_c1_1 = nn.Conv3d(self.in_channels + 1, self.base_n_filter, kernel_size=3, stride=1, padding=1,
                                         bias=False)
        else:
            self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1,
                                         bias=False)
        self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1,
                                     bias=False)
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)

        self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)
        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)

        self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)
        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)

        self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)
        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)

        self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)
        self.predict_age = self.decode_age(4096)

    def decode_age(self,feat_in):
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_in,1),
            nn.ReLU())

    def conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def forward(self, x):

        #  Level 1 context pathway

        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # Element Wise Summation
        out += residual_1
        out = self.inorm3d_c1(out)
        out = self.lrelu(out)
        # Level 2 context pathway
        out = self.conv3d_c2(out)
        residual_2 = out
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.inorm3d_c2(out)
        out = self.lrelu(out)
        # Level 3 context pathway
        out = self.conv3d_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.inorm3d_c3(out)
        out = self.lrelu(out)
        # Level 4 context pathway
        out = self.conv3d_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.inorm3d_c4(out)
        out = self.lrelu(out)
        # Level 5
        out = self.conv3d_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5

        return self.predict_age(out)

    def load_partial_state_dict(self, state_dict):
        print('loading parameters onto new model...')
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                print('notice: {} is not part of new model and was not loaded.'.format(name))
                continue
            param = param.data
            own_state[name].copy_(param)



    def test(self,device='cpu'):

        # smoke test with the input/target shapes this model actually expects
        input_tensor = torch.rand(2, self.in_channels, *self.dim)
        ideal_out = torch.rand(2, self.n_classes)
        out = self.forward(input_tensor)
        assert ideal_out.shape == out.shape
        summary(self.to(torch.device(device)), (self.in_channels, *self.dim), device=device)
        # import torchsummaryX
        # torchsummaryX.summary(self, input_tensor.to(device))
        print("Unet3D test is complete")

The base model class:

"""
Implementation of BaseModel taken and modified from here
https://github.com/kwotsin/mimicry/blob/master/torch_mimicry/nets/basemodel/basemodel.py
"""

import os
from abc import ABC, abstractmethod
import torch
import torch.nn as nn


class BaseModel(nn.Module, ABC):
    r"""
    BaseModel with basic functionalities for checkpointing and restoration.
    """

    def __init__(self):
        super().__init__()
        self.best_loss = 1000000

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def test(self):
        """
        To be implemented by the subclass so that
        models can perform a forward propagation
        :return:
        """
        pass

    @property
    def device(self):
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores checkpoint from a pth file and restores optimizer state.

        Args:
            ckpt_file (str): A PyTorch pth file containing model weights.
            optimizer (Optimizer): A vanilla optimizer to have its state restored from.

        Returns:
            int: Global step variable where the model was last checkpointed.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
        # Restore model weights
        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Restore optimizer status if existing. Evaluation doesn't need this
        # TODO return optimizer?????
        if optimizer:
            optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])

        # Return global step
        return ckpt_dict['epoch']

    def save_checkpoint(self,
                        directory,
                        epoch, loss,
                        optimizer=None,
                        name=None,phase=None):
        r"""
        Saves checkpoint at a certain global step during training. Optimizer state
        is also saved together.

        Args:
            directory (str): Path to save checkpoint to.
            epoch (int): The training epoch.
            loss (float): The current loss, used to track the best checkpoint.
            optimizer (Optimizer): Optimizer state to be saved concurrently.
            name (str): The name to save the checkpoint file as.

        Returns:
            None
        """
        if phase is not None:
            name = name+'_age_prediction'
        # Create directory to save to
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Build checkpoint dict to save.
        ckpt_dict = {
            'model_state_dict':
                self.state_dict(),
            'optimizer_state_dict':
                optimizer.state_dict() if optimizer is not None else None,
            'epoch':
                epoch
        }

        # Save the file with specific name
        if name is None:
            name = "{}_{}_epoch.pth".format(
                os.path.basename(directory),  # netD or netG
                'last')

        torch.save(ckpt_dict, os.path.join(directory, name))
        if self.best_loss > loss:
            self.best_loss = loss
            name = "{}_BEST.pth".format(
                os.path.basename(directory))
            torch.save(ckpt_dict, os.path.join(directory, name))

    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Args: None

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.

        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters()
                                   if p.requires_grad)

        return num_total_params, num_trainable_params

    def inference(self, input_tensor):
        self.eval()
        with torch.no_grad():
            output = self.forward(input_tensor)
            if isinstance(output, tuple):
                output = output[0]
            return output.cpu().detach()

This is now enough to fully reproduce the behavior (together with the code in the replies above):

import torch
import lib.medzoo as medzoo
import argparse
from torch.nn import MSELoss


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--title', default='first attempt at end to end age prediction without prior training')
    parser.add_argument('--batchSz', type=int, default=2)
    parser.add_argument('--dataset_name', type=str, default="fmri_hcp")
    parser.add_argument('--dim', nargs="+", type=int, default=(64, 64, 32))
    parser.add_argument('--nEpochs', type=int, default=1)
    parser.add_argument('--nEpochs_age_prediction', type=int, default=100)
    parser.add_argument('--augmentation', action='store_true', default=False)
    parser.add_argument('--inChannels', type=int, default=1)
    parser.add_argument('--inModalities', type=int, default=1)
    parser.add_argument('--split', default=0.8, type=float, help='Select percentage of training data(default: 0.8)')
    parser.add_argument('--lr', default=1e-2, type=float,
                        help='learning rate (default: 5e-3)')
    parser.add_argument('--lr_age_prediction', default=1e-3, type=float,
                        help='learning rate (default: 5e-3)')
    parser.add_argument('--lrsteps', default=[1,4,12,30,70,120,180,270,350,500,700])
    parser.add_argument('--inpaint_location', default=[40,30,17])
    parser.add_argument('--overfit', default=False)
    parser.add_argument('--norm_method', default='global') # pre_voxel
    parser.add_argument('--terminal_show_freq', default=1000)
    parser.add_argument('--classes', type=int, default=1)
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--context_0', default=False)
    parser.add_argument('--loadData', default=False)
    parser.add_argument('--workers', default=6)
    parser.add_argument('--time_step', default=1)
    parser.add_argument('--subject_num', default=1)
    parser.add_argument('--subject_embedding', default=False)
    parser.add_argument('--predict_age', default=True)
    parser.add_argument('--plot_log_scale', default=True)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='UNET3D_gony',
                        choices=('VNET', 'VNET2', 'UNET3D', 'UNET3D_gony','DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
    parser.add_argument('--opt', type=str, default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--log_dir', type=str,
                        default='../runs/')

    args = parser.parse_args()

    return args

args = get_arguments()
input = torch.rand((2, 1, 64, 64, 32))
subj = torch.tensor([1, 2])
target = torch.tensor([[14.0], [5.0]])   # shape (batch 2, dim 1), float, to match the model output

model = half_Unet_age()                  # the class defined above
criterion = MSELoss()
if args.cuda:
    input = input.cuda()
    subj = subj.cuda()
    target = target.cuda()
    model = model.cuda()
# create the optimizer only after the model is on its final device and any weights are loaded
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
for batch_idx in range(50):
    optimizer.zero_grad()
    output = model(input)                # forward() takes a single input tensor
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('loss: ', loss.item())

This is not really a “small piece of code” haha

The way I would go about reducing this is:

  • First check that the .grad field does get populated (you can del p.grad on a parameter to remove it and check, after the call to .backward(), that it is back again).
  • Check whether this .grad field is full of 0s.
  • Start adding hooks in your forward function, like out.register_hook(print), that will print the gradient for the Tensor out (see the sketch below). You should have a non-zero gradient at the very end at the loss (it will be 1); then you can check that the network's output gets a non-zero gradient as well, and then move up the network to see where it becomes 0 and what is causing it.
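A minimal, self-contained sketch of that hook-based debugging (toy layers, not the real net):

import torch
import torch.nn as nn

lin1 = nn.Linear(8, 4)
lin2 = nn.Linear(4, 1)

x = torch.rand(2, 8)
h = torch.relu(lin1(x))
h.register_hook(lambda g: print('grad at hidden activation:', g))  # fires during backward
out = lin2(h)
out.register_hook(print)  # gradient of the loss w.r.t. the network output
loss = nn.functional.mse_loss(out, torch.zeros(2, 1))
loss.backward()           # hooks print from the output back towards the input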

We try to solve a similar issue here, I guess.
Loss per batch is not decreasing, need help! - PyTorch Forums

This is getting stranger by the minute:

  1. As expected, the gradient at the loss is 1.
  2. Just before the output, the layers are: flatten layer --> fully connected (in 4096, out 1) --> ReLU (and then the MSE loss).
  3. Calling out.register_hook(print) at these three layers yields:
tensor([[-20.],
        [-20.]], device='cuda:0')
tensor([[0.],
        [0.]], device='cuda:0')
tensor([[-0., -0., 0.,  ..., -0., -0., -0.],
        [-0., -0., 0.,  ..., -0., -0., -0.]], device='cuda:0')
loss:  500.0
tensor([[-20.0000],
        [-19.9298]], device='cuda:0')
tensor([[  0.0000],
        [-19.9298]], device='cuda:0')
tensor([[-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0891,  0.2238, -0.0680,  ...,  0.2725,  0.2421,  0.2217]],
       device='cuda:0')
loss:  498.59820556640625
tensor([[-5.2828],
        [ 9.6693]], device='cuda:0')
tensor([[-5.2828],
        [ 9.6693]], device='cuda:0')
tensor([[-0.0157,  0.0200, -0.0573,  ...,  0.1115,  0.0249,  0.0194],
        [ 0.0287, -0.0366,  0.1050,  ..., -0.2042, -0.0455, -0.0356]],
       device='cuda:0')
loss:  160.7019805908203
tensor([[-20.],
        [-20.]], device='cuda:0')
tensor([[0.],
        [0.]], device='cuda:0')
tensor([[0., 0., 0.,  ..., -0., 0., 0.],
        [0., 0., 0.,  ..., -0., 0., 0.]], device='cuda:0')
loss:  500.0
tensor([[-20.],
        [-20.]], device='cuda:0')
tensor([[0.],
        [0.]], device='cuda:0')
tensor([[0., 0., 0.,  ..., -0., 0., 0.],
        [0., 0., 0.,  ..., -0., 0., 0.]], device='cuda:0')
loss:  500.0
tensor([[-20.],
        [-20.]], device='cuda:0')
tensor([[0.],
        [0.]], device='cuda:0')
tensor([[0., 0., 0.,  ..., -0., 0., 0.],
        [0., 0., 0.,  ..., -0., 0., 0.]], device='cuda:0')
loss:  500.0

What does this mean?

Also, I tried another approach and replaced my half-net with a simple 3-layer MLP.
I fed it one sample and got learning that stopped after about 300 epochs without converging to the correct answer.
The loss stabilized on a fixed number and refused to move after those 300 epochs.

How is this all related?

OK! It seems the entire drama was solved by changing the ReLU at the end of the network to a LeakyReLU.

I don't fully understand what happened; maybe albanD can comment?

My intuition is that the entire network is composed of LeakyReLUs and norm layers, so it must contain a lot of negative activations, and these happen to sum to a negative value going into the ReLU, thus producing a zero activation in the last neuron.

Having a zero activation at the last neuron shouldn't, as I understand it, prevent backprop from occurring normally: calculating dLoss/dLastNeuron doesn't involve this zero value.

BUT, when calculating the change in the loss with respect to the fully connected layer, the gradient can get trapped by the ReLU and a fairly negative pre-activation sum.

What I mean is that a small change in the activation of any one of the neurons summing into the ReLU would not change the output of the ReLU at all, because under a small change it stays negative. LeakyReLU solved this issue.

Can anyone comment on this?

Thanks for helping, everyone!

That most likely means that your input + weights were generating an output (the pre-activation going into the ReLU) that was all negative. So you were in the flat part of the ReLU for all elements and all the gradients were 0.
The LeakyReLU has no flat region, which avoids this issue.
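A toy illustration of that flat-region effect (a sketch, not the original net):

import torch
import torch.nn as nn

# an all-negative pre-activation, like the -20s printed by the hooks above
pre = torch.full((2, 1), -20.0, requires_grad=True)

relu_out = nn.functional.relu(pre)
relu_out.sum().backward()
print(pre.grad)   # tensor([[0.], [0.]])  -- ReLU blocks the gradient completely

pre.grad = None
lrelu_out = nn.functional.leaky_relu(pre)   # default negative_slope=0.01
lrelu_out.sum().backward()
print(pre.grad)   # tensor([[0.0100], [0.0100]])  -- small but non-zero, so learning can continue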