I’m trying to train CLIP in my own dataset, The model is not learning anything, the validation loss doesn’t reduce after the first epoch. I’m attaching my training code here, Please LMK whether I make any mistake.
def train_epoch(epoch, model, trainloader, optim, train_image_loss, train_txt_loss):
model.train()
train_loss = 0.0
pbar = tqdm(enumerate(trainloader), total=len(trainloader))
for i, batch in pbar:
image, caption = batch
image = image.to(device).float()
caption = caption.to(device).long()
image_logits, text_logits = model(image, caption)
gt = torch.arange(cfg.batch).to(device)
total_train_image_loss = train_image_loss(image_logits, gt)
total_train_text_loss = train_text_loss(text_logits, gt)
total_train_loss = (total_train_image_loss + total_train_text_loss)/2
total_train_loss.backward()
train_loss += total_train_loss.item()
train_epoch_loss = train_loss / len(trainloader)
return train_epoch_loss
def valid_epoch(epoch, model, valloader, val_image_loss, val_txt_loss):
model.eval()
val_loss = 0.0
pbar = tqdm(enumerate(valloader), total=len(valloader))
for i, batch in pbar:
image, caption = batch
image = image.to(device).float()
caption = caption.to(device).long()
with torch.no_grad():
image_logits, text_logits = model(image, caption)
gt = torch.arange(cfg.batch).to(device)
total_val_loss = (val_image_loss(image_logits, gt) + val_txt_loss(text_logits, gt))/2
val_loss += total_val_loss.item()
val_epoch_loss = val_loss / len(valloader)
return val_epoch_loss
for i in range(FOLD_NUM):
RUN_NAME = f"Finetune_Fold_{i}"
labels = [x.strip() for x in open("labels.txt", "r").readlines()]
images = len(glob.glob("resized2/*/*"))
samples_per_class = [len(os.listdir(f"resized2/{x}")) for x in labels]
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
clip.model.convert_weights(model)
for params in model.parameters():
params.requires_grad_(True)
train_data = CLIPDataset(data, preprocess, i, mode="train", aug=cfg.aug_mode)
val_data = CLIPDataset(data, preprocess, i, mode="val", aug=cfg.aug_mode)
trainloader = DataLoader(
train_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True
)
valloader = DataLoader(val_data, batch_size=cfg.batch, num_workers=16, pin_memory=True, drop_last=True)
# optimizer = torch.optim.Adam(model.parameters(), lr=cfg.initial_lr, momentum=cfg.momentum, nesterov=cfg.nesterov)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2) #Params used from paper, the lr is smaller, more safe for fine tuning to new dataset
print("model, optimizer and scheduler are loaded and ready to go..........")
train_image_loss, train_text_loss = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
val_image_loss, val_text_loss = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
with mlflow.start_run():
best_val_loss = float("inf")
no_improve_counter = 0
for e in range(cfg.epochs):
train_epoch_loss = train_epoch(e, model, trainloader, optimizer, train_image_loss, train_text_loss)
mlflow.log_metric(key="train_loss", value=train_epoch_loss, step=e)
val_epoch_loss = valid_epoch(e, model, valloader, val_image_loss, val_text_loss)
mlflow.log_metric(key="val_loss", value=val_epoch_loss, step=e)
if val_epoch_loss < best_val_loss:
print(f"The model improved from {best_val_loss} to {val_epoch_loss}")
best_val_loss = val_epoch_loss
torch.save({
"epoch": e,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}, f"{cfg.model_dir}/clip_finetuned_fold{i}_loss{val_epoch_loss:.4f}.pth")
else:
print(f"No Improvements: {val_epoch_loss}")
no_improve_counter += 1
if no_improve_counter % cfg.counter_patience == 0:
break