Pytorch Object detection model does not update upon taking a step

Hi all,

I am trying to train an object detector in pytorch however even though my gradients are non zero and losses are generated when I use the step function my models parameters do not change.

Here is my code:

import os
from statistics import mode
import torch
from PIL import Image

from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import retinanet_resnet50_fpn


from torchvision import transforms
import pandas as pd

from torchsummary import summary
from tqdm import tqdm

import matplotlib.pyplot as plt

class DataGen(torch.utils.data.Dataset):
    def __init__(self, csv_path, images_folder):
        df = pd.read_csv(csv_path)
        df = df.reset_index()
        self.images_folder = images_folder
        self.images_name = df['image_name']
        self.class_id=pd.to_numeric(df['class_id'])
        self.length = len(df)
        
        self.xmin=df['xmin']
        self.xmax=df['xmax']
        self.ymin=df['ymin']
        self.ymax=df['ymax']

    def __getitem__(self, index):
        # load images
        img_path=os.path.join(self.images_folder,
                              self.images_name[index])


        input_image = Image.open(img_path)
        preprocess = transforms.Compose([
            
            transforms.ToTensor(),


        ])      
        img = preprocess(input_image)


        # TODO: Add multiple images into one batch
        img = img.reshape([1,img.shape[0],img.shape[1],img.shape[2]])


        # TODO: Implement multiple bbox
        xmin=self.xmin[index]/img.shape[1]
        xmax=self.xmax[index]/img.shape[1]
        ymin=self.ymin[index]/img.shape[2]
        ymax=self.ymax[index]/img.shape[2]


        # convert everything into a torch.Tensor
        boxes = torch.as_tensor([[xmin, ymin, xmax, ymax]], dtype=torch.float32)

        labels = torch.as_tensor([int(self.class_id[index])], dtype=torch.int64)



        target = {}
        target["boxes"] = boxes
        target["labels"] = labels

        return img, target

    def __len__(self):
        return len(self.images_name)

def train(model, optimizer, lr_scheduler, train_set, val_set, device, epochs, scaler ):
    running_training_loss = []
    running_validation_loss = []
    # TODO: Add in checkpointing
    # TODO: Add in early stopping
    for epoch in range(epochs):
        epoch_loss = 0.0
        print("Starting Epoch ", epoch+1, "/",epochs)
        model.train()
        a = list(model.parameters())[0].clone()
        for x in tqdm(range(len(train_set))):
            images, target = train_set[x]
            images = list(image.to(device) for image in images)
            target = {k: v.to(device) for k, v in target.items()} 
            loss_dict = model(images, [target])
            losses = sum(loss for loss in loss_dict.values())
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                scaler.scale(losses).backward()
                # grads = []
                # for param in model.parameters():
                #     if param.requires_grad:
                #         grads.append(param.grad.view(-1))
                # grads = torch.cat(grads)
                # print(grads)

                # for param in model.parameters():
                    
                #     if param.requires_grad:
                #         print(param.grad.data.sum())
                scaler.step(optimizer)
                scaler.update()
                if lr_scheduler is not None:
                    lr_scheduler.step()
            epoch_loss += losses.item()
        b = list(model.parameters())[0].clone()
        print(torch.equal(a.data, b.data))
        running_training_loss.append(epoch_loss) 
        print("Total training loss: ", str(epoch_loss))
        print("Average training loss: ", str(epoch_loss/len(train_set)))
        epoch_loss = 0.0
        print("Validating...")
        model.eval()
        total_correct = 0
        with torch.no_grad():
            for x in tqdm(range(len(val_set))):
                images, target = val_set[x]
                images = list(image.to(device) for image in images)
                target = {k: v.to(device) for k, v in target.items()} 
                output = model(images)[0]

                correct = False
                for x in output['labels']:
                    if x.item() == target['labels'].item():
                        correct = True
                if correct:
                    total_correct +=1
                # TODO: Implement IOU accuracy
                # TODO: Calculate loss

        running_validation_loss.append(epoch_loss)
        print("Total validation loss: ", str(epoch_loss))
        print("Average validation loss: ", str(epoch_loss/len(val_set)))
        print("Validation Accuracy: ", str(total_correct/len(val_set)))
        if epoch !=0:
            if running_validation_loss[epoch-1] > running_validation_loss[epoch]:
                torch.save(model.state_dict(), "models/first_try.pt")   
    return model, running_training_loss, running_validation_loss

@torch.inference_mode() # Dont know why but this makes inference faster
def test(model, test_set, device):
    total_correct = 0
    model.eval()
    with torch.no_grad():
        for x in tqdm(range(len(test_set))):
            images, target = test_set[x]
            images = list(image.to(device) for image in images)
            output = model(images)[0]
            correct = False
            # TODO: Implement IOU accuracy
            for x in output['labels']:
                if x.item() == target['labels'].item():
                    correct = True
            if correct:
                total_correct +=1

    return total_correct/len(test_set)



def main():
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dataset_train = DataGen('data/Suspension_10/train_small.csv','data/Suspension_10/images' )
    dataset_val = DataGen('data/Suspension_10/validation_small.csv','data/Suspension_10/images')


    # model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT", box_score_thresh=0.7)
    model = retinanet_resnet50_fpn(weights="DEFAULT")

    # in_features = model.roi_heads.box_predictor.cls_score.in_features
    # # replace the pre-trained head with a new one
    # model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 8)
    model.to(device)


    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
                              

    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    scaler = torch.cuda.amp.GradScaler()

    model, train_loss, validation_loss = train(model, optimizer, lr_scheduler, dataset_train, dataset_val, device, 5, scaler)

    fig = plt.figure(figsize=(10,8))
    plt.plot(range(1,len(train_loss)+1),train_loss, label='Training Loss')
    plt.plot(range(1,len(validation_loss)+1),validation_loss,label='Validation Loss')
    plt.title("Learning Curve Graph")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.ylim(0, 1) # consistent scale
    plt.xlim(0, len(train_loss)+1) # consistent scale
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
    fig.savefig('model_training.png', bbox_inches='tight')

    dataset_test = DataGen('data/Suspension_10/test.csv','data/Suspension_10/images')
    test_accuracy = test(model, dataset_test, device)
    print("Test accuracy:", test_accuracy)

    


main()

Any ideas on what im doing wrong or things I should check to make sure I am training correctly

Are you seeing warnings of invalid gradients coming from amp? Based on your code snippet I would assume that the first parameter updates would be skipped due to a high loss scaling factor. If all updates are skipped it could indicate an invalid model output or loss value which would then also cause invalid gradients.

No warnings or errors the code runs all the way through. I have tried running the code with and without the scalar to no avail

My output for the loss dic is as follows:

{'classification': tensor(2.2468, device='cuda:0', grad_fn=<DivBackward0>), 'bbox_regression': tensor(5.4697, device='cuda:0', grad_fn=<DivBackward0>)}

and my grad outputs look something like this:

tensor(-413431.7812, device='cuda:0')
tensor(-2979375., device='cuda:0')
tensor(795280.3750, device='cuda:0')
tensor(864209.1875, device='cuda:0')
tensor(51765.6133, device='cuda:0')
tensor(414170.1250, device='cuda:0')
tensor(162745.4062, device='cuda:0')
tensor(102840.3828, device='cuda:0')
tensor(2653697., device='cuda:0')
tensor(1541843.7500, device='cuda:0')
tensor(-911746.6250, device='cuda:0')
tensor(5526663.5000, device='cuda:0')
tensor(568972.3750, device='cuda:0')