Multiple model.forward / Multiple models.forward followed by one loss.backward

Hello everyone,

I want to know which of three similar implementations for training a bi-encoder model in PyTorch with NLL (used as a triplet loss) allocates the least memory. The loss takes anchor, positive and negative outputs, and the implementations differ as follows (a sketch of the NLL I have in mind follows this list):

  1. uses only one model and does a forward pass in each batch iteration for the positives and negatives
  2. uses two models (with two optimizers) and does a forward pass in each batch iteration for the positives and negatives
  3. uses one model and does a single forward pass over all anchors, positives and negatives
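
For context, the NLL I have in mind looks roughly like this (a minimal sketch with dot-product similarity; the exact scoring function is not important for the question):

import torch
import torch.nn.functional as F

def NLL(anchors_embed, pos_embed, neg_embed):
    # anchors_embed, pos_embed: (B, D); neg_embed: (N_neg, D)
    pos_scores = (anchors_embed * pos_embed).sum(dim=-1, keepdim=True)  # (B, 1)
    neg_scores = anchors_embed @ neg_embed.t()                          # (B, N_neg)
    scores = torch.cat([pos_scores, neg_scores], dim=1)                 # (B, 1 + N_neg)
    # Negative log-likelihood of the positive (index 0) under a softmax over all candidates.
    return -F.log_softmax(scores, dim=1)[:, 0].mean()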

My goal is to reduce the memory allocated for the positives and negatives so that I can fit more negatives on the GPU in each batch. I am not sure whether the first two implementations actually allocate less memory, since each forward pass creates its own computational graph. Also, in all three implementations I must do a single forward pass over all anchors at once and then call loss.backward(retain_graph=True) for each batch of anchors.

The problem with the third one is that all the positives and negatives are computed at once, so their activations occupy memory for the whole iteration.
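
To compare the variants, I am measuring the peak allocated GPU memory like this (a minimal sketch; train_one_iteration is a hypothetical wrapper around one of the three loops below):

import torch

torch.cuda.reset_peak_memory_stats()

train_one_iteration()  # hypothetical: runs one of the three loops below for a single batch

peak_mib = torch.cuda.max_memory_allocated() / 1024 ** 2
print(f"peak allocated memory: {peak_mib:.1f} MiB")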

Implementation 1 (pseudocode):
optimizer.zero_grad()

# One forward pass over all anchors; this graph has to stay alive across all batch iterations.
all_anchors_embed = bi_encoder.forward(all_anchors_input)

for b_anchors_embed, b_pos_inputs, b_neg_inputs in zip(all_anchors_embed, all_pos_input, all_neg_input):
    # Forward pass only for this batch's positives and negatives.
    b_pos_embed, b_neg_embed = bi_encoder.forward(b_pos_inputs, b_neg_inputs)

    loss = NLL(b_anchors_embed, b_pos_embed, b_neg_embed)
    # retain_graph=True keeps the anchor graph for the next batch iteration.
    loss.backward(retain_graph=True)

optimizer.step()
Implementation 2 (pseudocode):
optimizer_encoder_1.zero_grad()
optimizer_encoder_2.zero_grad()

# encoder_1 embeds all anchors in one forward pass; its graph stays alive across batch iterations.
all_anchors_embed = encoder_1.forward(all_anchors_input)

for b_anchors_embed, b_pos_inputs, b_neg_inputs in zip(all_anchors_embed, all_pos_input, all_neg_input):
    # encoder_2 embeds only this batch's positives and negatives.
    b_pos_embed, b_neg_embed = encoder_2.forward(b_pos_inputs, b_neg_inputs)

    loss = NLL(b_anchors_embed, b_pos_embed, b_neg_embed)
    # retain_graph=True keeps the anchor graph from encoder_1 for the next batch iteration.
    loss.backward(retain_graph=True)

optimizer_encoder_1.step()
optimizer_encoder_2.step()
Implementation 3 (similar to Multiple model.forward followed by one loss.backward) (pseudocode):
optimizer.zero_grad()

# One forward pass over all anchors, positives and negatives at once;
# all of their activations stay in memory until the last backward call.
all_anchors_embed, all_pos_embed, all_neg_embed = bi_encoder.forward(all_anchors_input, all_pos_input, all_neg_input)

for b_anchors_embed, b_pos_embed, b_neg_embed in zip(all_anchors_embed, all_pos_embed, all_neg_embed):
    loss = NLL(b_anchors_embed, b_pos_embed, b_neg_embed)
    loss.backward(retain_graph=True)

optimizer.step()

Which one do you think will fit more negatives in memory?

Keep in mind that the bi-encoder in 1) and 3) has two encoders inside: one to encode the anchors and another to encode the positives and negatives.
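
The structure is roughly the following (a minimal sketch; the class and attribute names are only illustrative, and the forward signature is simplified compared to the pseudocode above):

import torch.nn as nn

class BiEncoder(nn.Module):
    # Illustrative only: two separate encoders wrapped in one module.
    def __init__(self, anchor_encoder, context_encoder):
        super().__init__()
        self.anchor_encoder = anchor_encoder      # encodes the anchors
        self.context_encoder = context_encoder    # encodes the positives and the negatives

    def forward(self, anchors_input=None, pos_input=None, neg_input=None):
        anchors_embed = self.anchor_encoder(anchors_input) if anchors_input is not None else None
        pos_embed = self.context_encoder(pos_input) if pos_input is not None else None
        neg_embed = self.context_encoder(neg_input) if neg_input is not None else None
        return anchors_embed, pos_embed, neg_embed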