Pretrained FCN not training

Hey, I am trying to train a FCN model(https://pytorch.org/vision/stable/models/generated/torchvision.models.segmentation.fcn_resnet50.html#torchvision.models.segmentation.fcn_resnet50) on a segmentation dataset however the loss doesn’t seem to decrease and converge and I’m totally clueless why. I do my backward and forward propagation as well and zero grad my optimizer and .train() my model when needed.

This is my training code:

import torch 
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from torchsummary import summary


if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = "cpu"
    print('Running on the CPU')

MODEL_PATH = '/content/best_model.pth'
LOAD_MODEL = False
BATCH_SIZE = 16 
LEARNING_RATE = 0.001
EPOCHS = 100
CLASSES = ['grass', 'weed', 'crop']
TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def check_accuracy(loader, model, num_classes=3, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device).float()
            y = y.to(device).float()
            
            preds = model(x)
            preds = torch.sigmoid(preds['out'])
            for i in range(0, num_classes):

                preds_ = (preds[:, i, :, :] > 0.5).float()

                y_ = y[:, i, :, :]
                num_correct += (preds_ == y_).sum()
                
                num_pixels += torch.numel(preds_) 
                dice_score += (2 * (preds_ * y_).sum()) / (
                    (preds_ + y_).sum() + 1e-8
                )

    print(
       f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    # print(f"Dice score: {dice_score/len(loader)}")
    model.train()

    return dice_score/(len(loader) * num_classes)


def train_function(data, model, optimizer, loss_fn, device):
    print('Entering into train function')
    loss_fn.requires_grad = True
    loss_values = []
    data = tqdm(data)
    model.train()
    for index, batch in enumerate(data): 
        X, y = batch
        X, y = X.to(device), y.to(device)
        #X = X.permute(0, 3, 1, 2).float()
        #y = y.permute(0, 3, 1, 2).float()
        #print(X.shape)
        #print(y.shape)
        optimizer.zero_grad()
        preds = model(X)
        #visualize(image=preds['out'][0].permute(1, 2, 0).cpu())
        preds['out'].requires_grad = True
        loss = loss_fn(preds['out'], y)
        loss.backward()
        optimizer.step()

        data.set_description(f"Loss: {loss.item()}")
    
    print(f"Dice Score: {check_accuracy(data, model)}")

    return loss.item()
        

def main():
    global epoch
    epoch = 0 # epoch is initially assigned to 0. If LOAD_MODEL is true then
              # epoch is set to the last value + 1. 
    LOSS_VALS = [] # Defining a list to store loss values after every epoch

    train_dataset = WeedsDataset("/content/dataset/images", '/content/dataset/annotations', classes=['grass', 'weed', 'crop'],
                                 process=TRANSFORMS, process_mask=TRANSFORMS_MASK)


    train_set = DataLoader(train_dataset, batch_size=8, shuffle=True)

    print('Data Loaded Successfully!')

    # Defining the model, optimizer and loss function
    model = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet50', num_classes=3, pretrained=False).train().to(DEVICE)

    

    ct = 0
    for child in model.children():
        ct += 1
        if ct < 10:
            for param in child.parameters():
                param.requires_grad = False

    params = [p for p in model.parameters() if p.requires_grad]
    
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_function = nn.BCEWithLogitsLoss().to(DEVICE)

    # Loading a previous stored model from MODEL_PATH variable
    if LOAD_MODEL == True:
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        epoch = checkpoint['epoch']+1
        LOSS_VALS = checkpoint['loss_values']
        print("Model successfully loaded!")    

    #Training the model for every epoch. 
    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        loss_val = train_function(train_set, model, optimizer, loss_function, DEVICE)
        
        LOSS_VALS.append(loss_val) 
        torch.save({
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS
        }, MODEL_PATH)

        
        print("Epoch completed and model successfully saved!")

main()

The loss:

Let me know if there is anything that I can fix.

These lines of code look wrong as it should not be needed to set the .requires_grad attribute of the model output to True unless it was detached.
If so, then the previously used parameters to calculate preds will not be updated which could explain the issue you are running into.
If you are seeing an error if you are not using preds['out'].requires_grad = True, check which operation in the model’s forward method could detach the tensor (e.g. by rewrapping an activation into a torch.tensor, by using a 3rd party library such as numpy, or by explicitly calling detach() on a tensor).

This is where I define my model, I’m not sure if part of the model would be detached from the graph because I specified .train() as well as it being a pretrained model.

No, calling train() or using a pretrained model will not detach the computation graph.
In case you are freezing all parameters of the model there won’t be a computation graph in the first place. Did you check what happens if you are removing the preds['out'].requires_grad = True call?

Yep I checked and it returns the error RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Nevermind, I figured out the problem, thanks for all your help!