CLIP training - no progression

I’m trying to train CLIP on my own dataset, but the model is not learning anything; the validation loss doesn’t decrease after the first epoch. I’m attaching my training code below. Please let me know whether I’m making any mistake.

def train_epoch(epoch, model, trainloader, optim, train_image_loss, train_txt_loss):
    model.train()
    train_loss = 0.0
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    for i, batch in pbar:
        image, caption = batch
        image = image.to(device).float()
        caption = caption.to(device).long()

        image_logits, text_logits = model(image, caption)

        gt = torch.arange(cfg.batch).to(device)

        total_train_image_loss = train_image_loss(image_logits, gt)
        total_train_text_loss = train_text_loss(text_logits, gt)
        
        total_train_loss = (total_train_image_loss + total_train_text_loss)/2
        total_train_loss.backward()

        train_loss += total_train_loss.item()
    
    train_epoch_loss = train_loss / len(trainloader)

    return train_epoch_loss


def valid_epoch(epoch, model, valloader, val_image_loss, val_txt_loss):
    model.eval()
    val_loss = 0.0
    pbar = tqdm(enumerate(valloader), total=len(valloader))
    for i, batch in pbar:
        image, caption = batch
        image = image.to(device).float()
        caption = caption.to(device).long()
        
        with torch.no_grad():
            image_logits, text_logits = model(image, caption)

        gt = torch.arange(cfg.batch).to(device)

        total_val_loss = (val_image_loss(image_logits, gt) + val_txt_loss(text_logits, gt))/2

        val_loss += total_val_loss.item()
    
    val_epoch_loss = val_loss / len(valloader)

    return val_epoch_loss


        


for i in range(FOLD_NUM):
    RUN_NAME = f"Finetune_Fold_{i}"
    

    labels = [x.strip() for x in open("labels.txt", "r").readlines()]
    images = len(glob.glob("resized2/*/*"))
    samples_per_class = [len(os.listdir(f"resized2/{x}")) for x in labels]

    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    clip.model.convert_weights(model)
    
    for params in model.parameters():
        params.requires_grad_(True)
        
    

    train_data = CLIPDataset(data, preprocess, i, mode="train", aug=cfg.aug_mode)
    val_data = CLIPDataset(data, preprocess, i, mode="val", aug=cfg.aug_mode)

    trainloader = DataLoader(
        train_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True
    )
    valloader = DataLoader(val_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True)
    
    # optimizer = torch.optim.Adam(model.parameters(), lr=cfg.initial_lr, momentum=cfg.momentum, nesterov=cfg.nesterov)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)  # params from the paper; the lr is smaller, which is safer for fine-tuning on a new dataset


    print("model, optimizer and scheduler are loaded and ready to go..........")


    train_image_loss, train_text_loss = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
    val_image_loss, val_text_loss = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

    
    with mlflow.start_run():       
        best_val_loss = float("inf")        
        no_improve_counter = 0
            
        for e in range(cfg.epochs):
            train_epoch_loss = train_epoch(e, model, trainloader, optimizer, train_image_loss, train_text_loss)
            mlflow.log_metric(key="train_loss", value=train_epoch_loss, step=e)

            val_epoch_loss = valid_epoch(e, model, valloader, val_image_loss, val_text_loss)
            mlflow.log_metric(key="val_loss", value=val_epoch_loss, step=e)

            if val_epoch_loss < best_val_loss:
                print(f"The model improved from {best_val_loss} to {val_epoch_loss}")
                best_val_loss = val_epoch_loss
                torch.save({
                    "epoch": e,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict()
                }, f"{cfg.model_dir}/clip_finetuned_fold{i}_loss{val_epoch_loss:.4f}.pth")
            

            else:
                print(f"No Improvements: {val_epoch_loss}")
                no_improve_counter += 1
                
                if no_improve_counter % cfg.counter_patience == 0:
                    break

It looks like your train function is passed an optimizer (optim), but it is never actually used inside the function, which would explain why the loss never changes: the model parameters are never updated.

Note that even though gradients are being computed, you still need an optimizer step to update the model parameters.
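
For reference, a minimal sketch of the standard per-batch update, using the optim and total_train_loss names from your train_epoch (the zero_grad and step calls are what the posted loop is missing):

# after computing total_train_loss for the current batch
optim.zero_grad()            # clear gradients left over from the previous batch
total_train_loss.backward()  # compute gradients for this batch
optim.step()                 # apply the parameter update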

Yeah, that’s right. I was in a bit of a haste and forgot this completely. Thanks a lot @eqy

I have added the optimizer step to the code, but the model is still not learning and the training saturates at one point.

You may want to check that the model parameters are actually changing with each iteration, e.g., by summing them and checking if the total value changes.
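
For example, a quick sketch of such a check, placed right after the optimizer step (model and i as in your training loop):

# rough sanity check: this total should change from one iteration to the next
with torch.no_grad():
    param_sum = sum(p.sum().item() for p in model.parameters())
print(f"param sum after step {i}: {param_sum}")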


The loss actually jumps much higher and eventually becomes NaN. I also tried gradient clipping, but nothing works.

Could you share how you added your optimizer to the training loop?

Training Script

def train_epoch(epoch, model, trainloader, optim, train_loss_fn, logit_scale, neptune_logger):
    
    model.train()
    train_loss = 0.0
    batch_norm = []
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    for i, batch in pbar:
        image, caption = batch
        image = image.to(device).float()
        caption = caption.to(device).long()

        # image_logits, text_logits = model(image, caption)
        
        image_logits = model.encode_image(image)
        text_logits = model.encode_text(caption)
        
      
        total_train_loss = train_loss_fn(image_logits, text_logits, logit_scale)
        total_train_loss.backward()

        #gradient clipping to avoid exploding gradient problem
        nn.utils.clip_grad_value_(model.parameters(), clip_value=2.0)

        # replace_layers(model)
        optim.step()

        pbar.set_description(f"Train_Step_Loss:{total_train_loss.item()}")
        train_loss += total_train_loss.item()
        
    train_epoch_loss = train_loss / len(trainloader)
    neptune_logger["train/epoch_loss"].append(train_epoch_loss)

    return train_epoch_loss, batch_norm


def valid_epoch(epoch, model, valloader, val_loss_fn, logit_scale, neptune_logger):
    model.eval()
    val_loss = 0.0
    pbar = tqdm(enumerate(valloader), total=len(valloader))
    for i, batch in pbar:
        image, caption = batch
        image = image.to(device).float()
        caption = caption.to(device).long()
        
        with torch.no_grad():
            # image_logits, text_logits = model(image, caption)
            image_logits = model.encode_image(image)
            text_logits = model.encode_text(caption)
        
        total_val_loss = val_loss_fn(image_logits, text_logits, logit_scale)
        val_loss += total_val_loss.item()
    
    val_epoch_loss = val_loss / len(valloader)
    neptune_logger["val/epoch_loss"].append(val_epoch_loss)
    
    return val_epoch_loss


        


for i in range(FOLD_NUM):
    RUN_NAME = f"Finetune_Fold_{i}_{date}"
    
    run = neptune.init_run(project=f"tensorthiru/{EXPERIMENT_NAME}", custom_run_id=RUN_NAME)
    
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
    
    labels = [x.strip() for x in open("labels.txt", "r").readlines()]
    images = len(glob.glob("resized2/*/*"))
    samples_per_class = [len(os.listdir(f"resized2/{x}")) for x in labels]
    
    m = torch.jit.load("ViT-B-32.pt").state_dict()
    model, preprocess = build_model_preprocess(m)
    model = model.to("cuda:0")

    train_data = CLIPDataset(data, preprocess, i, mode="train", aug=cfg.aug_mode)
    val_data = CLIPDataset(data, preprocess, i, mode="val", aug=cfg.aug_mode)

    trainloader = DataLoader(
        train_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True
    )
    valloader = DataLoader(val_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, eps=1e-04)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # params from the paper; the lr is smaller, which is safer for fine-tuning on a new dataset


    print("model, optimizer and scheduler are loaded and ready to go..........")


    train_loss = ClipLoss()
    val_loss = ClipLoss() 

    
    
    best_val_loss = float("inf")        
    no_improve_counter = 0

    for e in range(cfg.epochs):
        train_epoch_loss, batch_norm = train_epoch(e, model, trainloader, optimizer, train_loss, logit_scale, run)

        val_epoch_loss = valid_epoch(e, model, valloader, val_loss, logit_scale, run)

        if val_epoch_loss < best_val_loss:
            print(f"The model improved from {best_val_loss} to {val_epoch_loss}")
            best_val_loss = val_epoch_loss
            torch.save({
                "epoch": e,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }, f"{cfg.model_dir}/clip_finetuned_fold{i}_loss{val_epoch_loss:.4f}.pth")


        else:
            print(f"No Improvements: {val_epoch_loss}")
            no_improve_counter += 1
            if no_improve_counter % cfg.counter_patience == 0:
                break

CLIP Loss Function

import torch
import torch.nn as nn
import torch.nn.functional as F
# gather_features comes from open_clip and is only needed for distributed (world_size > 1) training


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
        
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss

You do not have a zero_grad call in your training loop (see torch.optim.Optimizer.zero_grad in the PyTorch documentation), which causes the gradients to be accumulated across all batches rather than being computed correctly for each batch. Eventually these accumulated gradients would also overflow, which would explain the NaN loss.
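
Concretely, the per-batch section of your new train_epoch could look like this (a sketch using your existing names; the zero_grad call is the only addition):

optim.zero_grad()  # reset gradients accumulated from the previous batch
total_train_loss = train_loss_fn(image_logits, text_logits, logit_scale)
total_train_loss.backward()

# gradient clipping to limit exploding gradients
nn.utils.clip_grad_value_(model.parameters(), clip_value=2.0)

optim.step()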


Even then, the step loss remains the same; there is no improvement in it. Is there any other correction that I have to make in the training script?

I would verify that your model is capable of overfitting a trivially small dataset (e.g., just one or a few examples) to rule out any other errors in the training loop. Also, are you zeroing the grads before or after each optimizer step? Note that you should zero them either before the backward call or after the optimizer step, but not between the two, as in that case the optimizer step will do nothing.
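
A minimal sketch of such an overfitting test, reusing the train_data, model, optimizer, train_loss and logit_scale objects from your script (the subset size and step counts here are arbitrary):

from torch.utils.data import DataLoader, Subset

# grab a single small batch and train on it repeatedly; the loss should
# drop towards zero within a few hundred steps if the loop is correct
tiny_loader = DataLoader(Subset(train_data, list(range(cfg.batch))),
                         batch_size=cfg.batch, drop_last=True)
image, caption = next(iter(tiny_loader))
image = image.to(device).float()
caption = caption.to(device).long()
scale = logit_scale.to(device)  # keep the scale on the same device as the features

for step in range(300):
    optimizer.zero_grad()
    loss = train_loss(model.encode_image(image), model.encode_text(caption), scale)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(f"step {step}: loss {loss.item():.4f}")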