Model seems to be learning fine, but performs badly when not doing a loss step

I’m baffled on this one.

I am training with this core loop (batch size is 1):

for image_data_s, demographics_s, actual_age_s in train_loader:
      image_data_s = image_data_s.to(device)
      demographics_s = demographics_s.to(device)
      actual_age_s = actual_age_s.to(device)

      predicted_ages = model(image_data_s, demographics_s)
      predicted_ages = predicted_ages.squeeze(1)

      X.append(actual_age_s.clone().detach().item())
      Y.append(predicted_ages.clone().detach().item())

      loss = criterion(predicted_ages, actual_age_s)
      
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

As you can see, I save each training prediction and target so that I can plot them and see how training is going.

By the second epoch, the results during training already look OK (and they get much better by epoch 10):

However, when I run this trained model without the backward step (on the training data, to be clear), I get this output:

I am completely stumped. I get this by running the exact same code directly after training, but with these lines commented out:

 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
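For reference, here is a sketch of a conventional evaluation pass (not the poster's code). The forward passes run under torch.no_grad() and the model is switched to eval mode first, so BatchNorm uses its running statistics and Dropout is disabled. The toy model, dataset and loss are hypothetical stand-ins. Simply commenting out the three lines above, as in the question, leaves the model in train mode.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the real model and data, only to make the sketch runnable
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
dataset = TensorDataset(torch.randn(10, 4), torch.rand(10))
loader = DataLoader(dataset, batch_size=1)
criterion = nn.MSELoss(reduction="sum")
device = torch.device("cpu")

actual, predicted = [], []
total_loss = 0.0
model.eval()            # BatchNorm uses running stats, Dropout is disabled
with torch.no_grad():   # no autograd graph is recorded during evaluation
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).squeeze(1)          # [N, 1] -> [N]
        total_loss += criterion(preds, targets).item()
        actual.extend(targets.tolist())           # works for any batch size, unlike .item()
        predicted.extend(preds.tolist())
model.train()           # switch back before resuming training
```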

What am I missing?

Could you post the model definition or (better) a minimal and executable code snippet to reproduce the issue, please?

Hiya, thanks for the response. It’s going to be difficult to get a snippet that runs, so let’s see if you can spot anything off with the model definition:



n_input_channels = 1

n_latent_channels = 128

total_down_conv_channels = n_latent_channels

first_down_res = True 

first_down_res_stride = (2 if first_down_res else 1)

second_down_res = True

second_down_res_stride = (2 if second_down_res else 1)

#

kernel_size = 3
padding = 1

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        # [200, 256, 256, x]

        self.el0a = nn.Conv3d(in_channels=n_input_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=2, bias=False)   
        self.el0b = nn.ReLU(inplace=True)
        self.el0b_bn = nn.BatchNorm3d(total_down_conv_channels)
    
        # [100, 128, 128, 64]

        self.el1 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=first_down_res_stride, bias=False)   
        self.el2 = nn.ReLU(inplace=True)
        self.el2_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.max_pool_1 = nn.MaxPool3d(kernel_size=[2,2,2])

        # [50, 64, 64, 64]

        self.el3 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el4 = nn.ReLU(inplace=True)
        self.el4_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.el3a = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el4a = nn.ReLU(inplace=True)
        self.el4a_bn = nn.BatchNorm3d(total_down_conv_channels)


        self.mid_conv_dropout = nn.Dropout(p=0.3)

        self.max_pool_2 = nn.MaxPool3d(kernel_size=[2,2,2])

        # [50, 64, 64, 64]

        self.el5 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=second_down_res_stride, bias=False)   
        self.el6 = nn.ReLU(inplace=True)
        self.el6_bn = nn.BatchNorm3d(total_down_conv_channels)

        # [25, 32, 32, 64]

        self.el7 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el8 = nn.ReLU(inplace=True)
        self.el8_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.el7a = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el8a = nn.ReLU(inplace=True)
        self.el8a_bn = nn.BatchNorm3d(total_down_conv_channels)

        # [25, 32, 32, 64]

        self.el9 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=1, padding=0, stride=1, bias=False)

        self.max_pool_3 = nn.MaxPool3d(kernel_size=[2,2,2])

        self.el9_bn = nn.BatchNorm3d(total_down_conv_channels)

        #############

        self.global_pool = nn.AdaptiveAvgPool3d(output_size=1)

        self.fc1 = nn.Linear(n_latent_channels + demographics_length, 256).double()
        self.fc1_relu = nn.ReLU()

        self.mid_fc_dropout = nn.Dropout(p=0.3)

        self.fc2 = nn.Linear(256, 1).double()
        self.fc2_relu = nn.ReLU()

        #############

      

        
    def forward(self, image_data, demographics):
        x = self.el0a(image_data)
        x = self.el0b(x)
        x = self.el0b_bn(x)

        x = self.el1(x)
        x = self.el2(x)
        x = self.max_pool_1(x)
        x = self.el2_bn(x)
        x = self.el3(x)
        x = self.el4(x)
        x = self.max_pool_2(x)
        x = self.el4_bn(x)

        x = self.el3a(x)
        x = self.el4a(x)
        x = self.el4a_bn(x)

        x = self.mid_conv_dropout(x)

        x = self.el5(x)
        x = self.el6(x)
        x = self.el6_bn(x)

        x = self.el7(x)
        x = self.el8(x)
        x = self.el8_bn(x)

        x = self.el7a(x)
        x = self.el8a(x)
        x = self.el8a_bn(x)

        x = self.el9(x)
        x = self.max_pool_3(x)
        x = self.el9_bn(x)

        #


        x = self.global_pool(x) # this produces one per channel

        x = x.squeeze() # this is now shape [n_channels]
        x = torch.reshape(x, (1, total_down_conv_channels))

        x = torch.cat([x,demographics], axis=1)

        x = self.fc1(x)
        x = self.fc1_relu(x)

        x = self.mid_fc_dropout(x)

        x = self.fc2(x)
        x = self.fc2_relu(x)

        #

        predicted_age = x

        return predicted_age

Your code is missing a lot of definitions, so I won’t be able to execute and experiment with it.
In any case, these lines look wrong:

        x = x.squeeze() # this is now shape [n_channels]
        x = torch.reshape(x, (1, total_down_conv_channels))

as it seems you are forcing the batch size to be 1.
Is this always the case, or could some unwanted broadcasting happen somewhere?
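To illustrate the concern, here is a small sketch using a hypothetical pooled tensor for a batch of two. Hard-coding (1, total_down_conv_channels) fails as soon as N > 1, and x.squeeze() also drops the batch dimension when N == 1. By contrast, torch.flatten(x, start_dim=1) keeps a [N, C] layout for any batch size.

```python
import torch

n_channels = 128
# Hypothetical output of the global pooling for a batch of 2: [N, C, 1, 1, 1]
pooled = torch.randn(2, n_channels, 1, 1, 1)

# Hard-coding the batch dimension only works when N == 1
try:
    torch.reshape(pooled, (1, n_channels))
    reshape_failed = False
except RuntimeError:
    reshape_failed = True  # 256 elements cannot be viewed as [1, 128]

# squeeze() with no argument removes every size-1 dim, including the batch dim for N == 1
single = pooled[:1].squeeze()          # shape [128], batch dimension lost

# Flattening everything after the batch dimension works for any N
features = torch.flatten(pooled, start_dim=1)
print(features.shape)                  # torch.Size([2, 128])
```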

Hm. I was under the mistaken impression that by the time data was going through Module::forward it was one sample.

(As it happens though, the batch size is 1 at all times)

If that line looks wrong, though, we could try to fix it: would x.squeeze(1) be right? The output of self.global_pool should have shape [128,1] if I’m understanding AdaptiveAvgPool3d correctly, and I want it to have shape [128] and then be concatenated with the demographics.
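A quick way to check the shapes is to print them. The sketch below assumes a batch of one and a hypothetical demographics_length of 3, since the real value is not shown in the thread. It suggests that AdaptiveAvgPool3d(output_size=1) keeps the batch dimension and all three spatial dimensions ([N, C, 1, 1, 1]). That means squeeze(1) is a no-op here, and flattening from dim 1 gives the [N, C] layout needed for concatenation along dim 1.

```python
import torch
from torch import nn

n_latent_channels = 128
demographics_length = 3  # hypothetical value; the real one is not shown in the thread

pool = nn.AdaptiveAvgPool3d(output_size=1)
feature_map = torch.randn(1, n_latent_channels, 6, 8, 8)  # [N, C, D, H, W], toy spatial size
demographics = torch.randn(1, demographics_length)        # [N, demographics_length]

pooled = pool(feature_map)
print(pooled.shape)        # torch.Size([1, 128, 1, 1, 1])

# Dimension 1 holds the 128 channels, so squeeze(1) removes nothing
squeezed = pooled.squeeze(1)
print(squeezed.shape)      # torch.Size([1, 128, 1, 1, 1])

features = pooled.flatten(start_dim=1)                 # [1, 128]
combined = torch.cat([features, demographics], dim=1)  # [1, 128 + 3]
print(combined.shape)      # torch.Size([1, 131])
```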

(I’ve added some definitions to the code above)

I’ve removed the dropout layers for testing (and reduced the number of training samples to speed up my debugging), and this is interesting:

Epoch 2 during training:

Without training, after epoch 2, on the same data:

I bet this is something silly but I can’t see it

Did you keep the model in .train() mode? If so, note that the batchnorm stats would still be updated. However, if you never switch to .eval() these running stats also won’t be used, so it’s still unclear to me what’s causing the issue.
Could you describe your workflow again so that I could try to use random data to reproduce it?
In particular: when .train()/.eval() is called, when the model is updated, etc.
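The BatchNorm behaviour described above can be seen on a single toy layer. The sketch uses a random single-sample input standing in for one scan. In train mode, a forward pass alone (even under no_grad, with no optimizer) updates the running statistics and normalizes with the current batch's own statistics. In eval mode, the same input is normalized with the running estimates instead, which gives a different output.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm3d(4)
x = torch.randn(1, 4, 3, 3, 3) * 5 + 2  # one sample, four channels, toy spatial size

bn.train()
before = bn.running_mean.clone()
with torch.no_grad():  # no backward pass and no optimizer involved
    y_train = bn(x)
# The forward pass alone updated the running statistics
running_stats_changed = not torch.equal(before, bn.running_mean)

# In train mode each channel is normalized with this batch's own statistics,
# so the per-channel mean of the output is (numerically) zero
per_channel_mean = y_train.mean(dim=(0, 2, 3, 4))

bn.eval()
with torch.no_grad():
    y_eval = bn(x)  # normalized with the running estimates instead
```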

I do keep the model in train mode (I’ve tried switching to eval mode first, but that didn’t help).

I’ve created a version which should work on its own without any edits - thank you for taking a look.



import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from matplotlib import pyplot as plt

###############

saved_model_filename = 'brain_age_network.pt'

###

learning_rate = 1e-3

batch_size = 1

n_global_epochs = 2 # going around all of the scans


###

rescaled_image_resolution = [200,256,256]

#

dropout = False

n_input_channels = 1

n_latent_channels = 128

total_down_conv_channels = n_latent_channels

first_down_res = True 

first_down_res_stride = (2 if first_down_res else 1)

second_down_res = True

second_down_res_stride = (2 if second_down_res else 1)

#

kernel_size = 3
padding = 1


###

class RandDataset(Dataset):
    def __init__(self, seed):
        np.random.seed(seed)

        self.image_data = np.random.rand(1, rescaled_image_resolution[0],rescaled_image_resolution[1],rescaled_image_resolution[2])

        self.age = np.random.random()

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.image_data, self.age


#

class BNModel(nn.Module):
    def __init__(self):
        super().__init__()

        # [200, 256, 256, x]

        self.el0a = nn.Conv3d(in_channels=n_input_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=2, bias=False)   
        self.el0b = nn.ReLU(inplace=True)
        self.el0b_bn = nn.BatchNorm3d(total_down_conv_channels)
    
        # [100, 128, 128, 64]

        self.el1 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=first_down_res_stride, bias=False)   
        self.el2 = nn.ReLU(inplace=True)
        self.el2_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.max_pool_1 = nn.MaxPool3d(kernel_size=[2,2,2])

        # [50, 64, 64, 64]

        self.el3 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el4 = nn.ReLU(inplace=True)
        self.el4_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.el3a = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el4a = nn.ReLU(inplace=True)
        self.el4a_bn = nn.BatchNorm3d(total_down_conv_channels)

        if dropout:
            self.mid_conv_dropout = nn.Dropout(p=0.3)

        self.max_pool_2 = nn.MaxPool3d(kernel_size=[2,2,2])

        # [50, 64, 64, 64]

        self.el5 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=second_down_res_stride, bias=False)   
        self.el6 = nn.ReLU(inplace=True)
        self.el6_bn = nn.BatchNorm3d(total_down_conv_channels)

        # [25, 32, 32, 64]

        self.el7 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el8 = nn.ReLU(inplace=True)
        self.el8_bn = nn.BatchNorm3d(total_down_conv_channels)

        self.el7a = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=kernel_size, padding=padding, stride=1, bias=False)
        self.el8a = nn.ReLU(inplace=True)
        self.el8a_bn = nn.BatchNorm3d(total_down_conv_channels)

        # [25, 32, 32, 64]

        self.el9 = nn.Conv3d(in_channels=total_down_conv_channels,out_channels=total_down_conv_channels,kernel_size=1, padding=0, stride=1, bias=False)

        self.max_pool_3 = nn.MaxPool3d(kernel_size=[2,2,2])

        self.el9_bn = nn.BatchNorm3d(total_down_conv_channels)

        #############

        self.global_pool = nn.AdaptiveAvgPool3d(output_size=1)

        self.fc1 = nn.Linear(n_latent_channels, 256)

        self.fc1_relu = nn.ReLU()

        if dropout:
            self.mid_fc_dropout = nn.Dropout(p=0.3)

        self.fc2 = nn.Linear(256, 1)

        self.fc2_relu = nn.ReLU()

        #############

      

        
    def forward(self, image_data):
        x = self.el0a(image_data)
        x = self.el0b(x)
        x = self.el0b_bn(x)

        x = self.el1(x)
        x = self.el2(x)
        x = self.max_pool_1(x)
        x = self.el2_bn(x)
        x = self.el3(x)
        x = self.el4(x)
        x = self.max_pool_2(x)
        x = self.el4_bn(x)

        x = self.el3a(x)
        x = self.el4a(x)
        x = self.el4a_bn(x)

        if dropout:
            x = self.mid_conv_dropout(x)

        x = self.el5(x)
        x = self.el6(x)
        x = self.el6_bn(x)

        x = self.el7(x)
        x = self.el8(x)
        x = self.el8_bn(x)

        x = self.el7a(x)
        x = self.el8a(x)
        x = self.el8a_bn(x)

        x = self.el9(x)
        x = self.max_pool_3(x)
        x = self.el9_bn(x)

        #

        x = self.global_pool(x) # this produces one per channel

        x = torch.reshape(x, (1, total_down_conv_channels))
        

        x = self.fc1(x)
        x = self.fc1_relu(x)

        if dropout:
            x = self.mid_fc_dropout(x)

        x = self.fc2(x)
        x = self.fc2_relu(x)

        #

        predicted_age = x

        return predicted_age


###



def train_model():
    # set random seeds
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    # setup device cuda vs. cpu
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")

    assert cuda, "we require cuda for the brain network"

    model = BNModel().to(device)

    #

    training_seeds = [np.random.randint(100000,999999) for i in range(50)]

    # Setting the optimiser
    
    try:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            fused=True, # this does more of the work on the GPU
        )
    except Exception as e:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
        )

    #

    criterion = torch.nn.MSELoss(reduction='sum')
    
    X = []
    Y = []

   #

    from matplotlib.backends.backend_pdf import PdfPages

    pdf = PdfPages("brain_age_model_training.pdf")

    #


    for global_epoch in range(1, n_global_epochs+1):
        global_epoch_train_loss = None
        n_global_epoch_trainings = 0

        X = []
        Y = []

        for seed in training_seeds:
            try:
                dataset = RandDataset(seed)
            except Exception as e:
                continue

            train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

            # train for one epoch
            model.train()
            
            for image_data_s, actual_age_s in train_loader:
                image_data_s = image_data_s.float().to(device)
                actual_age_s = actual_age_s.float().to(device)

                

                # ===================forward=====================
                predicted_ages = model(image_data_s)

                predicted_ages = predicted_ages.squeeze(1)

                X.append(actual_age_s.clone().detach().item())
                Y.append(predicted_ages.clone().detach().item())

                loss = criterion(predicted_ages, actual_age_s)
                
                local_loss = loss.clone().detach()
                if global_epoch_train_loss is None: # keep it on the gpu to avoid unnecessary cpu/gpu syncs
                    global_epoch_train_loss = local_loss
                else:
                    global_epoch_train_loss += local_loss
                n_global_epoch_trainings += 1


                # ===================backward====================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

          
    

        global_epoch_train_loss = global_epoch_train_loss.item()
        average_loss_per_scan = global_epoch_train_loss / n_global_epoch_trainings
        print(f"Global epoch average scan loss: {average_loss_per_scan}")

        # save a graph to a pdf to show our predictions vs reality

        fig, ax = plt.subplots()
        ax.plot([0.5,1.0],[0.5,1.0],c="tab:grey")
        ax.scatter(X, Y, s=2.2,c="tab:orange")
        ax.set_xlabel("Actual age")
        ax.set_ylabel("Brain age")
        fig.suptitle(f"Epoch {global_epoch}.  Average loss per subject: {average_loss_per_scan}", fontsize=8)
        pdf.savefig(fig)
        plt.close(fig)




    if True:
        X = []
        Y = []

        #with torch.no_grad():
        if True:
            #model.train()
            #model.eval()

            for seed in training_seeds:
                try:
                    dataset = RandDataset(seed)
                except Exception as e:
                    continue

                train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

                for image_data_s, actual_age_s in train_loader:

                    image_data_s = image_data_s.float().to(device)
                    actual_age_s = actual_age_s.float().to(device)

                    # ===================forward=====================
                    predicted_ages = model(image_data_s)
                    predicted_ages = predicted_ages.squeeze(1)

                    X.append(actual_age_s.clone().detach().item())
                    Y.append(predicted_ages.clone().detach().item())

                    loss = criterion(predicted_ages, actual_age_s)


        #

        fig, ax = plt.subplots()
        ax.plot([0.5,1.0],[0.5,1.0],c="tab:grey")
        ax.scatter(X, Y, s=2.2,c="tab:orange")
        ax.set_xlabel("Actual age")
        ax.set_ylabel("Brain age")
        fig.suptitle(f"POST TRAINING RESULTS NON-RELOADED SCANS", fontsize=8)
        pdf.savefig(fig)
        plt.close(fig)
        
    

    #

    # save the state dict of the model
    torch.save(model.state_dict(), saved_model_filename)

    #

    pdf.close()
    plt.close('all')


###


train_model()

This reproduces the same fundamental problem for me: during training we get at least some learning (although it’s noisier with fully random data):

And yet without the backward step (the final graph), we just get a flat output:

Any chance you could take a look at this @ptrblck ? I’m still stumped.

I think the lack of parameter updates in the second run lets the model collapse.
If I remove the optimizer.step() call from the first loop, I see:

Additionally, the last nn.ReLU layer might also cause trouble, as it would clip the output, and the prediction could end up reflecting only the bias.
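Here is a small sketch of the clipping effect, with hypothetical raw outputs of fc2 for three samples. The final ReLU maps every negative prediction to exactly 0, and those samples also receive zero gradient through it. Once a prediction is pushed below zero, the loss cannot pull it back up via that path.

```python
import torch
from torch import nn

fc2_relu = nn.ReLU()
# Hypothetical raw fc2 outputs for three samples
raw = torch.tensor([[-0.7], [-0.1], [0.4]], requires_grad=True)
target = torch.tensor([[0.6], [0.6], [0.6]])

pred = fc2_relu(raw)                 # negative values are clipped to 0
loss = ((pred - target) ** 2).sum()  # summed squared error, like MSELoss(reduction='sum')
loss.backward()

# raw.grad is zero for the two clipped samples: the loss cannot push them back up
print(pred.detach().flatten().tolist(), raw.grad.flatten().tolist())
```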

Thank you for taking a look.

That final relu does indeed look like a mistake, and I’ve removed it.

But I don’t understand the idea of “model collapse”. All that second loop does is test the model, not train it, so how can it produce a constant output?

You could check if only the bias of the last layer is used and all previous activations are clipped to zero.
Also, this effect is not specific to the second run as it’s already visible in the first loop if no parameters are updated.
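One way to run this check is a forward hook on the ReLU that records the fraction of activations clipped to exactly zero. The sketch uses a hypothetical toy head mirroring fc1 -> fc1_relu -> fc2 from the thread. For illustration, it forces the layer dead with a large negative bias to show the "bias only" case, where every output equals fc2's bias.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy head standing in for fc1 -> fc1_relu -> fc2 from the thread
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc1_relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        return self.fc2(self.fc1_relu(self.fc1(x)))

zero_fractions = []

def record_dead_fraction(module, inputs, output):
    # Fraction of activations that the ReLU clipped to exactly zero
    zero_fractions.append((output == 0).float().mean().item())

model = Head()
handle = model.fc1_relu.register_forward_hook(record_dead_fraction)
with torch.no_grad():
    model(torch.randn(4, 128))           # healthy layer: roughly half the units are zero
    model.fc1.bias.fill_(-1e3)           # artificially kill every unit (illustration only)
    out = model(torch.randn(4, 128))     # every output is now just fc2's bias
handle.remove()
print(zero_fractions)
```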

I think the culprit might be the AdaptiveAvgPool3d layer. Is there an alternative just for testing to get the [25, 32, 32, 64] shape 3D data into a shortish linear array?

Based on the definition of global_pool = nn.AdaptiveAvgPool3d(output_size=1) it seems you want to reduce the depth, height, and width to 1 via a mean operation.
An alternative would be to use max or sum, but I don’t know if this would help or if increasing the output size might be beneficial in your case.
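For comparison, here is a sketch of the three reductions on a hypothetical toy feature map in [N, C, D, H, W] layout. Each produces one value per channel ([N, C]). They differ only in how the spatial positions are combined: mean, maximum, or sum. The sum equals the mean scaled by the number of voxels.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 128, 5, 8, 8)  # hypothetical feature map, [N, C, D, H, W]

avg = nn.AdaptiveAvgPool3d(output_size=1)(x).flatten(1)  # mean over D, H, W -> [1, 128]
mx = nn.AdaptiveMaxPool3d(output_size=1)(x).flatten(1)   # max over D, H, W  -> [1, 128]
sm = x.sum(dim=(2, 3, 4))                                 # sum over D, H, W  -> [1, 128]
```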

Well, increasing the output size of the AdaptiveAvgPool3d layer has stopped the output from being flat, but it still isn’t the same as during the training runs. I suspect the problem is in the reshape call, but I don’t know how else I’m supposed to get data of shape, say, [1, 128, 4, 4, 4] into shape [1, 128 * 4 * 4 * 4] so that it can be put through a linear layer.
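A batch-size-agnostic way to do this is nn.Flatten(start_dim=1) (or torch.flatten(x, 1)). It keeps dimension 0 and merges everything after it, so [N, 128, 4, 4, 4] becomes [N, 8192]. The sketch below assumes a pool output size of 4, matching the example shape. It builds a hypothetical head without a final ReLU and runs it on a toy batch of two.

```python
import torch
from torch import nn

n_latent_channels = 128
pool_size = 4  # hypothetical AdaptiveAvgPool3d output size, matching the [1, 128, 4, 4, 4] example

head = nn.Sequential(
    nn.AdaptiveAvgPool3d(output_size=pool_size),          # [N, 128, D, H, W] -> [N, 128, 4, 4, 4]
    nn.Flatten(start_dim=1),                              # -> [N, 128 * 4 * 4 * 4] = [N, 8192]
    nn.Linear(n_latent_channels * pool_size ** 3, 256),
    nn.ReLU(),
    nn.Linear(256, 1),                                    # no ReLU on the regression output
)

x = torch.randn(2, n_latent_channels, 6, 8, 8)  # toy batch of 2 with an arbitrary spatial size
out = head(x)
print(out.shape)  # torch.Size([2, 1])
```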

And I still have no idea why the behaviour changes when you aren’t doing a backward step. Do you understand why?

I guess it’s because your model is always saturating its outputs, which also explains the static output during training if no parameters are updated.
If you call optimizer.step() in each iteration, the output will change slightly, and the training loop looks as if the outputs are improving, while you might just have shifted the static output value. You could test this by updating once and checking whether the next outputs all share the same, but new, value again.
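This test can be sketched as follows. Measure the spread (standard deviation) of the predictions across several different inputs, take a single optimizer step, and measure again. The model and data are hypothetical toys; because this toy is not saturated, it shows a non-zero spread. What would indicate the saturation described above in the real model is a spread near zero both before and after the step, with only the mean moving.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy regressor standing in for the real model
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss(reduction="sum")

inputs = torch.randn(8, 16)   # eight different samples
targets = torch.rand(8, 1)

def output_stats():
    # Spread of the predictions across inputs; near zero means a (nearly) constant output
    with torch.no_grad():
        preds = model(inputs)
    return preds.std().item(), preds.mean().item()

spread_before, mean_before = output_stats()

# A single update on one sample, as in the training loop
optimizer.zero_grad()
loss = criterion(model(inputs[:1]), targets[:1])
loss.backward()
optimizer.step()

spread_after, mean_after = output_stats()
print(f"spread {spread_before:.4f} -> {spread_after:.4f}, mean {mean_before:.4f} -> {mean_after:.4f}")
```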
In the end, I don’t think your model is training properly, as it cannot even overfit the training samples, so I would recommend focusing on this issue first.
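A common way to check this is to try to memorise a handful of samples: train repeatedly on them, then evaluate the same samples with no parameter updates. The sketch below uses a hypothetical toy model and random data. With the real network, one would substitute a few scans. If the model can fit them, the evaluation loss should approach zero.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical tiny model and a handful of samples standing in for a few real scans
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
inputs = torch.randn(4, 16)
targets = torch.rand(4, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

# Evaluate the same samples without any parameter updates
model.eval()
with torch.no_grad():
    final_loss = criterion(model(inputs), targets).item()
print(f"loss on the memorised samples: {final_loss:.2e}")
```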


Thank you as always for your help, I think that gets to the heart of the matter.