Training loss not decreasing

Hi,

I am currently trying to implement a 3D U-Net (crop-type-mapping/unet3d.py from roserustowicz/crop-type-mapping on GitHub) for satellite time series regression.

My input has shape [b, t, c, h, w] and the mask has shape [b, c, h, w]. In my preliminary training runs the training loss (MSE) is not decreasing (it stays exactly the same), and I'm just not sure where I've gone wrong.

Input: torch.Size([1, 12, 15, 256, 256])
Mask:  torch.Size([1, 1, 256, 256])
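
As a sanity check for a perfectly flat loss, a sketch like the one below (run right after loss.backward()) would flag parameters that receive no gradient:

# Diagnostic sketch: flag parameters with missing or all-zero gradients.
for name, p in model.named_parameters():
    if p.grad is None or p.grad.abs().sum() == 0:
        print(name, "has zero / no gradient")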

## model
"""
Taken from https://github.com/roserustowicz/crop-type-mapping/
Implementation by the authors of the paper :
"Semantic Segmentation of crop type in Africa: A novel Dataset and analysis of deep learning methods"
R.M. Rustowicz et al.

Slightly modified to support image sequences of varying length in the same batch.
"""

import torch
import torch.nn as nn


def conv_block(in_dim, middle_dim, out_dim):
    # two (Conv3d -> BatchNorm3d -> LeakyReLU) stages
    model = nn.Sequential(
        nn.Conv3d(in_dim, middle_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(middle_dim),
        nn.LeakyReLU(inplace=True),
        nn.Conv3d(middle_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
        nn.LeakyReLU(inplace=True),
    )
    return model


def center_in(in_dim, out_dim):
    # bottleneck input: a single conv stage
    model = nn.Sequential(
        nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
        nn.LeakyReLU(inplace=True))
    return model


def center_out(in_dim, out_dim):
    # bottleneck output: conv stage followed by a stride-2 transposed conv (2x upsampling)
    model = nn.Sequential(
        nn.Conv3d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(in_dim),
        nn.LeakyReLU(inplace=True),
        nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1))
    return model


def up_conv_block(in_dim, out_dim):
    # decoder upsampling: stride-2 transposed conv (2x upsampling)
    model = nn.Sequential(
        nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm3d(out_dim),
        nn.LeakyReLU(inplace=True),
    )
    return model


class UNet3D(nn.Module):
    def __init__(self, in_channel, n_classes, timesteps=12, dropout=0.5):
        super(UNet3D, self).__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes
        
        feats = 16
        self.en3 = conv_block(in_channel, feats * 4, feats * 4)
        self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.en4 = conv_block(feats * 4, feats * 8, feats * 8)
        self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.center_in = center_in(feats * 8, feats * 16)
        self.center_out = center_out(feats * 16, feats * 8)
        self.dc4 = conv_block(feats * 16, feats * 8, feats * 8)
        self.trans3 = up_conv_block(feats * 8, feats * 4)
        self.dc3 = conv_block(feats * 8, feats * 4, feats * 2)
        self.final = nn.Conv3d(feats * 2, n_classes, kernel_size=3, stride=1, padding=1)
        self.fn = nn.Linear(timesteps, 1)  # collapses the time dimension (T -> 1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(p=dropout, inplace=True)



    def forward(self, x):
        x = x.float()
        x = x.permute(0, 2, 1, 3, 4)  # [b, t, c, h, w] -> [b, c, t, h, w]
        out = x.cuda()
        en3 = self.en3(out)
        pool_3 = self.pool_3(en3)
        en4 = self.en4(pool_3)
        pool_4 = self.pool_4(en4)
        center_in = self.center_in(pool_4)
        center_out = self.center_out(center_in)
        concat4 = torch.cat([center_out, en4], dim=1)
        dc4 = self.dc4(concat4)
        trans3 = self.trans3(dc4)
        concat3 = torch.cat([trans3, en3], dim=1)
        dc3 = self.dc3(concat3)
        final = self.final(dc3)
        final = final.permute(0, 1, 3, 4, 2)  # [b, c, t, h, w] -> [b, c, h, w, t]

        shape_num = final.shape[0:4]               # [b, c, h, w]
        final = final.reshape(-1, final.shape[4])  # flatten to [b*c*h*w, t]
        final = self.dropout(final)
        final = self.fn(final)                     # collapse time: [b*c*h*w, 1]
        final = final.reshape(shape_num)           # back to [b, c, h, w]
        final = self.logsoftmax(final)

        return final
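
A minimal shape check for the model above (a sketch; it assumes a CUDA device is available, since forward() hard-codes .cuda()):

net = UNet3D(in_channel=15, n_classes=1).cuda()
dummy = torch.randn(1, 12, 15, 256, 256)   # [b, t, c, h, w]
print(net(dummy).shape)                    # expect torch.Size([1, 1, 256, 256])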

## initialise the model
model = UNet3D(in_channel=15, n_classes=1)
# initialise loss and optimiser
loss_module = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
criterion = loss_module

# training
def train(model, optimizer, criterion, train_loader, device=None):
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    train_running_RMSE = 0.0 
    counter = 0
    for i, batch in enumerate(train_loader):
        counter += 1
        batch = recursive_todevice(batch, device)
        input, label = batch
        print(input.shape)
        print(label.shape)
        optimizer.zero_grad()
        # forward pass
        outputs = model(input)
        # calculate the loss
        loss = criterion(outputs, label)
        train_running_loss += loss.item()
        # use .item() so we accumulate/log plain floats, not graph-attached tensors
        rmse = torch.sqrt(loss).item()
        train_running_RMSE += rmse
        ## Log losses to Neptune vis
        run["training/epoch/loss"].log(loss.item())
        run["training/epoch/rmse"].log(rmse)
        # calculate the accuracy
        # TODO
        # _, preds = torch.max(outputs.data, 1)
        # train_running_correct += (preds == labels).sum().item()

        # backpropagation
        loss.backward()
        # update the optimizer parameters
        optimizer.step()
    
    # loss and accuracy for the complete epoch
    epoch_loss = train_running_loss / counter
    
    #epoch_acc = 100. * (train_running_correct / len(train_loader.dataset)) ## TODO ACCURACY
    return epoch_loss #, epoch_acc

# validation
def validate(model, criterion, val_loader, device=None):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    val_running_RMSE = 0.0
    counter = 0
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            counter += 1
            
            batch = recursive_todevice(batch, device)
            input, label = batch
            # forward pass
            outputs = model(input)
            # calculate the loss
            loss = criterion(outputs, label)
            valid_running_loss += loss.item()
            rmse = torch.sqrt(loss).item()
            val_running_RMSE += rmse
            ## Log losses to Neptune vis
            run["val/epoch/loss"].log(loss.item())
            run["val/epoch/rmse"].log(rmse)
            
            # calculate the accuracy
            # _, preds = torch.max(outputs.data, 1)
            # valid_running_correct += (preds == labels).sum().item()
        
    # loss and accuracy for the complete epoch
    epoch_loss = valid_running_loss / counter
    #epoch_acc = 100. * (valid_running_correct / len(val_loader.dataset))
    return epoch_loss#, epoch_acc

# define how many epochs to train for

epochs = 5

# lists to keep track of losses and accuracies
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
# start the training
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(model=model, optimizer=optimizer, criterion=criterion,
                             train_loader=train_loader, device=device)
    valid_epoch_loss = validate(model=model, criterion=criterion,
                                val_loader=val_loader, device=device)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    print(f"Training loss: {train_epoch_loss:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}")
    # save the best model till now if we have the least loss in the current epoch
    save_best_model(
        valid_epoch_loss, epoch, model, optimizer, criterion
    )
    print('-'*50)
    
# save the trained model weights for a final time
save_model(epochs, model, optimizer, criterion)
# save the loss and accuracy plots

print('TRAINING COMPLETE')

I assume you are working on multi-class segmentation using this UNet implementation.
If so, note that the last layers in combination with the loss function look wrong:

  • A multi-class segmentation would use nn.CrossEntropyLoss or nn.NLLLoss as the criterion while you are using nn.MSELoss.
  • nn.LogSoftmax is used in combination with e.g. nn.NLLLoss to create log-probabilities.
  • The last layer is defined as nn.Linear(timesteps, 1), which could indicate a single output class (you are reshaping the output activation before applying the log-softmax, so I'm unsure what the actual shape of your output is). Assuming the second dimension is 1, the final output would be a tensor full of zeros (see the snippet below).
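
A tiny sketch of why a single-channel LogSoftmax gives all zeros: softmax over a size-1 dimension is always 1, and log(1) = 0:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 256, 256)   # single output channel
out = nn.LogSoftmax(dim=1)(x)     # softmax over dim 1 (size 1) is always 1
print(out.abs().max())            # tensor(0.) -- every value is zero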

This U-Net model is for multi-class segmentation; however, I need to adapt it for a regression problem.

I would need to output one segmentation mask for the whole time series input, if that makes sense?

If you are dealing with a regression problem, it seems that the usage of nn.LogSoftmax is wrong. Double-check whether this is causing an all-zero output and, if so, remove this activation.
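
As a minimal sketch, the end of forward() for the regression setup would then just return the raw values:

        # ... end of forward(), after reshaping back to [b, c, h, w]:
        final = final.reshape(shape_num)
        # no LogSoftmax for regression -- return raw values for nn.MSELoss
        return final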

Not quite, since a segmentation mask would not fit a regression use case unless you want to treat a classification as a regression.

Ok cheers, I will look into what the output predictions are.

The mask is a continuous data raster. When I've used PyTorch Lightning with this for a single-timestep prediction it works out OK, so it must be handling the output type.

But now with this custom model I’m not sure how to handle the regression output.

So I removed the LogSoftmax and ran some tests, and the loss is decreasing. I have yet to visualize any of the predictions, so I'm still not completely sure it is configured 100% correctly, but it looks more promising.

I assume for visualizing the predictions I wouldn't need to run argmax(), as it is a continuous output?

Yes, you should not use argmax on a single output channel, as it will again yield an output full of zeros. torch.argmax is used on an output with a channel size of nb_classes to get the class index corresponding to the max logit or probability.
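
For example (a sketch; the 5-class output below is hypothetical):

import torch

logits = torch.randn(1, 5, 256, 256)   # hypothetical 5-class segmentation output
preds = torch.argmax(logits, dim=1)    # class index per pixel, shape [1, 256, 256]

single = torch.randn(1, 1, 256, 256)   # single-channel regression output
torch.argmax(single, dim=1)            # always 0 -- not meaningful here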


And just to confirm: is there any issue in the current configuration with having a 5D input and a 4D mask?

Yes, since a 5D model output and a 4D target could indicate a multi-class segmentation use case for 3D volumes. Since you are planning to treat it as a regression problem, I would assume both shapes should be identical, but as already explained I don’t fully understand your approach and don’t know if you would depend on broadcasting or want to remove a dimension from your model output.
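
As a sketch, a safe pattern for the regression setup is to assert that output and target shapes match before computing the loss, since nn.MSELoss will broadcast mismatched shapes (with a warning) and silently compute something unintended:

import torch
import torch.nn as nn

criterion = nn.MSELoss()
out = torch.randn(1, 1, 256, 256)     # model output after collapsing the time dimension
target = torch.randn(1, 1, 256, 256)  # continuous mask
assert out.shape == target.shape      # avoid silent broadcasting inside MSELoss
loss = criterion(out, target)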

So, my problem is that I have a 12-timestep satellite image array (shape [12, 15, 256, 256]) and one 'mask' (shape [1, 256, 256]), with the mask being a continuous array rather than e.g. a binary segmentation mask.

I am trying to rework this 3D U-Net for that problem, as it was suggested that it could work (rather than using temporal encodings).

I have checked out the outputs of the network, listed below:

def forward(self, x):
        x = x.float()
        print("input shape BTCHW", x.shape)
        x = x.permute(0, 2, 1, 3, 4)
        print("permute BCTHW", x.shape)
        out = x.cuda()
        en3 = self.en3(out)
        pool_3 = self.pool_3(en3)
        en4 = self.en4(pool_3)
        pool_4 = self.pool_4(en4)
        center_in = self.center_in(pool_4)
        center_out = self.center_out(center_in)
        concat4 = torch.cat([center_out, en4], dim=1)
        dc4 = self.dc4(concat4)
        trans3 = self.trans3(dc4)
        concat3 = torch.cat([trans3, en3], dim=1)
        dc3 = self.dc3(concat3)
        final = self.final(dc3)
        final = final.permute(0, 1, 3, 4, 2)  # [b, c, t, h, w] -> [b, c, h, w, t]
        print("final permute BCHWT", final.shape)
        shape_num = final.shape[0:4]
        final = final.reshape(-1,final.shape[4])
        print("reshape1", final.shape)
        final = self.dropout(final)
        final = self.fn(final)
        print("linear", final.shape)
        final = final.reshape(shape_num)
        #final = self.logsoftmax(final)
        print("final", final.shape)

        return final

[INFO]: Epoch 1 of 5
Training
input shape BTCHW torch.Size([1, 12, 15, 256, 256])
permute BCTHW torch.Size([1, 15, 12, 256, 256])
final permute BCHWT torch.Size([1, 1, 256, 256, 12])
reshape1 torch.Size([65536, 12])
linear torch.Size([65536, 1])
final torch.Size([1, 1, 256, 256])
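
Assuming matplotlib is available, a minimal sketch for visualizing the continuous prediction (no argmax, per the advice above); batch_input is a hypothetical placeholder for one sample from the loader:

import matplotlib.pyplot as plt

pred = model(batch_input).detach().cpu().squeeze()  # [1, 1, 256, 256] -> [256, 256]
plt.imshow(pred, cmap="viridis")  # continuous values, so plot them directly
plt.colorbar()
plt.show()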