torch.nn.BCELoss is unsafe to autocast while working with cosine similarity

I’m trying to modify the CLIP network (GitHub - openai/CLIP: Contrastive Language-Image Pretraining).
This network receives a batch of paired images and texts and returns a matrix of cosine similarities between each text and each image.

The training code is something like this:

import torch
import torch.nn as nn

# define BS as batch_size, optimizer
loss_per_img = nn.BCELoss()
loss_per_txt = nn.BCELoss()

model = CLIP()  # actually not like this, but just assume this
for batch in dataloader:
    images, texts = batch  # images is (BS, ...) and texts is (BS, ...)
    # cosine_per_image is (BS, BS), where each entry is the cosine similarity
    # between image i and text j; cosine_per_text is just its transpose
    cosine_per_image, cosine_per_text = model(images, texts)
    loss_total = (loss_per_img(cosine_per_image, ground_truth)
                  + loss_per_txt(cosine_per_text, ground_truth)) / 2
    loss_total.backward()
    optimizer.step()
    optimizer.zero_grad()
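(ground_truth isn’t defined above; assuming each image matches exactly one text in the batch, it would presumably be a (BS, BS) matrix of 0s and 1s, e.g.:

ground_truth = torch.eye(BS)  # assumption: 1 where image i and text j are a true pair

with several 1s per row in the multi-target case mentioned below.)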

And this code works fine. Now I want to improve the speed by introducing mixed-precision training. Here’s my implementation:

use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
for batch in dataloader:
    with torch.cuda.amp.autocast(enabled=use_amp):
        ...  # <same as before>
        loss_total = (loss_per_img(cosine_per_image, ground_truth)
                      + loss_per_txt(cosine_per_text, ground_truth)) / 2
    scaler.scale(loss_total).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

This code gives an error saying that torch.nn.BCELoss is unsafe to autocast and asking me to use BCEWithLogitsLoss instead. The problem is that the cosine similarity in CLIP is computed as a dot product of embedding vectors, not from a sigmoid or softmax layer. That’s why I chose BCELoss directly.

Why not sigmoid? Because the values are cosine similarities, not logits for a sigmoid.
Also, I can’t apply softmax to the cosine similarity matrix, since the task is multi-target classification and I don’t want the probabilities to saturate on the class with the highest score.

Any idea how to use mixed-precision training without involving sigmoid/softmax?

You could either disable autocast for the loss calculation or implement a custom loss method, which you could make “autocast-safe”.
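For example, here is a minimal (untested) sketch of the first option, reusing the names from the snippets above; ground_truth is still assumed to be a float tensor of 0s and 1s:

for batch in dataloader:
    images, texts = batch
    with torch.cuda.amp.autocast(enabled=use_amp):
        cosine_per_image, cosine_per_text = model(images, texts)
    # Leave the autocast region for the loss: BCELoss is disallowed under
    # autocast, so run it with autocast disabled and cast the (possibly
    # float16) similarities back to float32 first.
    with torch.cuda.amp.autocast(enabled=False):
        loss_total = (loss_per_img(cosine_per_image.float(), ground_truth)
                      + loss_per_txt(cosine_per_text.float(), ground_truth)) / 2
    scaler.scale(loss_total).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Alternatively, for the second option, a hand-written binary cross-entropy built from clamp/log/mean never calls the disallowed binary_cross_entropy kernel, so it can run inside the autocast region. bce_from_probs and eps below are illustrative names, not part of CLIP or PyTorch:

def bce_from_probs(probs, target, eps=1e-7):
    # Manual BCE computed in float32; inputs must already lie in [0, 1],
    # exactly as nn.BCELoss expects.
    probs = probs.float().clamp(eps, 1 - eps)
    target = target.float()
    return -(target * probs.log() + (1 - target) * (1 - probs).log()).mean()

Note that, like nn.BCELoss itself, both variants assume the similarities have already been mapped into [0, 1] (e.g. via (x + 1) / 2), since raw cosine similarities can be negative.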
