Is this the right way to approach a pretrained model?

I am trying to train a text-to-image DCGAN using skip-thoughts embeddings.
For the PyTorch port of skip-thoughts I am using this library.

Currently, to convert my text to embeddings, I use the function below:

def convert_text_to_embeddings(batch_text):
  tokenized_labels, lengths = tokenizer.convert_batch_sentences_to_indices(batch_text)
  # Feeding batches of text to uniskip model
  uni_out = uniskip(tokenized_labels, lengths)
  # Feeding batches of text to biskip model
  bi_out = biskip(tokenized_labels, lengths)
  return torch.cat([uni_out, bi_out], dim=1).detach()

Notice that I detach the output before returning it.
I do this because, when I run the training step on the generator and discriminator without detaching, I get an "unreachable" error.

# Train Generator
generator.zero_grad()
input_noise = torch.randn(size=(current_batch_size, noise_size)).cuda()
fake_images = generator(input_noise, text_embeddings)
discriminator_fake_validity = discriminator(fake_images, text_embeddings)
generator_loss = adversarial_loss(discriminator_fake_validity, real_validity)

generator_loss.backward() # Error occurs here even with generator_loss.backward(retain_graph=True)
generator.optimizer.step()
generator_losses.append(generator_loss.item())  # .item() so the graph isn't stored

I was just wondering if I am making use of the pretrained embeddings the right way: should I or should I not be calling detach()? Also, while creating the uniskip and biskip models, I set them to eval mode.
Are these steps right? Any help would be appreciated!
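
For reference, the only alternative I know of is wrapping the encoder call in torch.no_grad(), which, as far as I understand, is the usual pattern for frozen feature extractors:

# alternative I am considering: torch.no_grad() instead of detach()
with torch.no_grad():
    text_embeddings = convert_text_to_embeddings(real_text)

I am not sure which of the two is preferable here.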

PS: If I have missed anything, or you would like to know more about how I implemented certain things (in case my examples are insufficient or not properly explained), do let me know; I will try to be as clear as possible.

I’m not familiar with the mentioned library, but could you post the complete error message you are seeing if you are not detaching?
Currently, the “skip” layers wouldn’t be trained, and I’m not sure if this fits your use case.


Hey! Thank you for replying

This is the error log I see if I am not detaching (with the models set back to train mode):

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

In general, I’d just love some advice on how to use pretrained embeddings.
I have also used Hugging Face Transformers (BERT), where I followed a similar approach: I set the model to eval mode (and output_hidden_states to True).
I then created a similar function (I referred to this):

def convert_text_to_embeddings(batch_text):
    # truncation=True is required for max_length to actually be applied
    encoded_input = tokenizer(batch_text, return_tensors='pt', padding=True, truncation=True, max_length=69).to('cuda')
    # Index 2 holds the hidden states of all layers (output_hidden_states=True)
    hidden_states = bert_model(**encoded_input)[2]
    # Mean-pool the last layer over the token dimension -> (batch_size, hidden_size)
    output = torch.mean(hidden_states[-1], dim=1)
    return output.detach()
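
Called on a batch of captions this returns one pooled vector per caption, e.g. (assuming bert_model is a bert-base variant, so hidden size 768):

caption_embeddings = convert_text_to_embeddings(['a small red bird', 'a yellow flower'])
print(caption_embeddings.shape)  # torch.Size([2, 768])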

My aim is to train a text-to-image GAN, so I am just trying to confirm that my approach is the right one :sweat_smile:
Thank you!

If you don’t want to train the embedding, then you can detach the output.
To use a pretrained embedding you would have to make sure you are loading the right weights.
Besides that it will just work as a lookup table.
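
For a plain nn.Embedding, loading pretrained weights and freezing them would look like this (with a random tensor standing in for the real weights):

import torch
import torch.nn as nn

# Stand-in for the real pretrained weights, shape (vocab_size, embedding_dim)
pretrained_weights = torch.randn(1000, 300)

# freeze=True sets requires_grad=False, so the weights won't be updated
emb = nn.Embedding.from_pretrained(pretrained_weights, freeze=True)

tokens = torch.tensor([[1, 5, 42]])
vectors = emb(tokens)  # pure lookup, shape (1, 3, 300)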

The RuntimeError regarding the second backward pass seems to point to another issue, and it’s a bit weird that detach() solves it.
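This error is usually raised when a tensor that is still attached to a computation graph is reused in a second backward pass after the first one has already freed the intermediate buffers, e.g.:

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)
tokens = torch.randint(0, 10, (2, 3))
emb = embedding(tokens)  # emb is attached to the embedding's graph

loss1 = emb.sum()
loss1.backward()  # frees the intermediate buffers of this graph

loss2 = (emb * 2.).sum()
loss2.backward()  # raises the "backward through the graph a second time" error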

Could you post a code snippet to reproduce this issue using a random embedding?


To install the skipthoughts library:

!pip install skipthoughts

Importing the libraries and initialising the models.
If it helps, here is the link on how to use the library; I followed their steps -> link

from skipthoughts import UniSkip, BiSkip
dir_st = 'skip-thoughts'
# Note that the word dict here is something I created
uniskip = UniSkip(dir_st, tokenizer.word_dict).cuda()  # pretrained
biskip = BiSkip(dir_st, tokenizer.word_dict).cuda()
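
As mentioned above, I set both models to eval mode right after creating them:

uniskip.eval()
biskip.eval()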

I use the function below to convert any sentence to embeddings

def convert_text_to_embeddings(batch_text):
  tokenized_labels, lengths = tokenizer.convert_batch_sentences_to_indices(batch_text)
  uni_out = uniskip(tokenized_labels, lengths)
  bi_out = biskip(tokenized_labels, lengths)
  return torch.cat([uni_out, bi_out], dim=1) # Note I add detach here to avoid the error
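
For reference, uni-skip and bi-skip each produce 2400-d sentence vectors, so the concatenated output should have shape (batch_size, 4800):

text_embeddings = convert_text_to_embeddings(['a bird with red wings'])
print(text_embeddings.shape)  # torch.Size([1, 4800])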

Below is my training step (please let me know if you would like to see my model architecture too)

pbar = tqdm()


for epoch in range(epochs):
    print(f'Epoch: {epoch + 1} / {epochs}')
    pbar.reset(total=len(weighted_dataloader))

    # Setting up losses
    discriminator_losses = []
    generator_losses = []

    for i, (real_images, real_text) in enumerate(weighted_dataloader):

        # Current batch size
        current_batch_size = real_images.size()[0]

        # Convert to cuda
        real_images = real_images.cuda()
        text_embeddings = convert_text_to_embeddings(real_text)

        # For real vs fake
        real_validity = torch.ones(current_batch_size, 1).cuda()
        fake_validity = torch.zeros(current_batch_size, 1).cuda()

        # Train Generator
        generator.zero_grad()
        input_noise = torch.randn(size=(current_batch_size, noise_size)).cuda()
        fake_images = generator(input_noise, text_embeddings)
        discriminator_fake_validity = discriminator(fake_images, text_embeddings)
        generator_loss = adversarial_loss(discriminator_fake_validity, real_validity)

        generator_loss.backward() # Does not work even if I set retain_graph to True
        generator.optimizer.step()
        generator_losses.append(generator_loss.item())  # .item() so the graph isn't stored

        # Train Discriminator
        discriminator.zero_grad()

        ## To calculate real loss
        discriminator_real_validity = discriminator(real_images, text_embeddings)
        discriminator_real_loss = adversarial_loss(discriminator_real_validity, real_validity)

        ## To calculate wrong loss (real images paired with mismatched text)
        mismatched_images = torch.cat([real_images[1:], real_images[:1]], dim=0)  # shift the batch by one
        discriminator_wrong_validity = discriminator(mismatched_images, text_embeddings)
        discriminator_wrong_loss = adversarial_loss(discriminator_wrong_validity, fake_validity)

        ## To calculate fake loss
        discriminator_fake_validity = discriminator(fake_images.detach(), text_embeddings)
        discriminator_fake_loss = adversarial_loss(discriminator_fake_validity, fake_validity)

        ## Calculating total loss
        discriminator_loss = discriminator_real_loss + discriminator_wrong_loss + discriminator_fake_loss
        discriminator_loss.backward()
        discriminator.optimizer.step()
        discriminator_losses.append(discriminator_loss.item())  # .item() so the graph isn't stored


        # Update tqdm
        pbar.update()

    print('Discriminator Loss: {:.3f}, Generator Loss: {:.3f}'.format(
            torch.mean(torch.FloatTensor(discriminator_losses)),
            torch.mean(torch.FloatTensor(generator_losses))
    ))
    
    if (epoch + 1) % output_after_every_n == 0:
        plot_output()

pbar.refresh()

Do let me know if there is anything else you’d like to know.
Thanks!