Trouble using CLIP, loss converges at 0.8

Hi,

I’m trying to train a network that learns to manipulate an image (e.g. change hair color). To guide the training I use the CLIP model, and the mapper operates on the latent space of VQGAN.

During training, the loss converges to about 0.8. In contrast, when the loss consists of only the MSE term between the encoding and the mapping, it drops to 0. I can’t figure out why the network gets stuck when the CLIP loss is added.
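
To make the objective explicit, here is a minimal sketch of what I intend the combined loss to be (simplified, with placeholder tensors rather than the real training code, which is below):

import torch.nn.functional as F

# clip_logits: image/text logits from CLIP (the released models scale cosine similarity by ~100)
# encoding:    VQGAN latent of the input image
# mapping:     output of the latent mapper
def combined_loss(clip_logits, encoding, mapping):
    clip_term = (1 - clip_logits / 100).mean()  # push the decoded image toward the target text
    mse_term = F.mse_loss(encoding, mapping)    # keep the mapped latent close to the original encoding
    return clip_term + mse_term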

I’ve tried different batch sizes and learning rates, and also tried adding the CLIP/VQGAN models to the network itself. Nothing works.

Would appreciate help!

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import PIL
import torchvision.transforms as T
from omegaconf import OmegaConf
import yaml
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
from torch.optim import Adam
from criteria.clipush import CLIPLoss
from CLIP import clip
from taming_transformers.taming.models.vqgan import VQModel

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_vqgan(config, ckpt_path=None):
    model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


def custom_to_pil(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x


def preprocess(img, target_image_size=224):
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = T.ToTensor()(img)
    return preprocess_vqgan(img)


def preprocess_vqgan(x):
    x = 2. * x - 1.
    return x


config_vqf4 = load_config("logs/vq-f4/configs/config.yaml", display=False)
model_vqf4 = load_vqgan(config_vqf4, ckpt_path="logs/vq-f4/checkpoints/model.ckpt").to(DEVICE)


# Define a convolutional neural network (the latent mapper)
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.LatentMapper = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, (3, 3), (1, 1), 1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 3, (3, 3), (1, 1), 1),
            nn.LeakyReLU()
        )

    def forward(self, batch):
        batch_encode = (model_vqf4.encode(batch.to(DEVICE)))[0]
        batch_mapping = self.LatentMapper(batch_encode)
        batch_decode = (model_vqf4.decode(batch_mapping.to(DEVICE)))
        return batch_encode, batch_mapping, batch_decode


# Instantiate a neural network model
model = Network()
loss_fn = nn.MSELoss().to(DEVICE).eval()
clip_global = CLIPLoss().eval()
optimizer = Adam(model.LatentMapper.parameters(), lr=0.005, weight_decay=0.0001)


# Function to save the model
def saveModel():
    path = "./myFirstModel.pth"
    torch.save(model.state_dict(), path)


# Training function: loop over the data iterator, feed the inputs to the network, and optimize.
def train(max_steps, batch_size):
    model.to(DEVICE)

    with torch.no_grad():
        dataset = ImageFolder(r'\PycharmProjects\latent_to_latent\images1024x1024', transform=preprocess)
        train_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=3)

    steps_counter = 0
    text_tar = ["person with blond hair"]
    text_target_tokenized = clip.tokenize(text_tar).to(device="cuda")

    while steps_counter < max_steps:

        for i, (images, labels) in enumerate(train_loader, 0):
            # get the inputs
            images = images.to(DEVICE)
            # zero the parameter gradients
            optimizer.zero_grad()
            # encode the images, map the latents, and decode back to image space
            batch_encode, batch_mapping, batch_decode = model(images)

            # combined loss: CLIP-guided term + MSE between encoding and mapping
            simi = 1 - clip_global.model(batch_decode, text_target_tokenized)[0] / 100
            loss = simi.mean() + loss_fn(batch_encode, batch_mapping)

            # backpropagate the loss
            loss.backward()
            # adjust parameters based on the calculated gradients
            optimizer.step()

            print(loss)

            if steps_counter % 500 == 0:
                saveModel()

            steps_counter += 1

            if steps_counter == max_steps:
                print('OMG, finished training!')
                saveModel()
                break


def main():
    train(40000, 1)


if __name__ == '__main__':
    main()

and clipush.py contains:


import torch
from CLIP import clip


class CLIPLoss(torch.nn.Module):

    def __init__(self):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")

    def forward(self, image, text):
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity

but I don’t call forward; in train() I use clip_global.model(...) directly instead.
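
For reference, if I did call the module, it would dispatch to forward via nn.Module.__call__ and give the same value as the manual expression in train() (same clip_global, loss_fn, and tensors as above):

# equivalent to: simi = 1 - clip_global.model(batch_decode, text_target_tokenized)[0] / 100
simi = clip_global(batch_decode, text_target_tokenized)
loss = simi.mean() + loss_fn(batch_encode, batch_mapping)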