Running multiple batches through the model before backpropagation

In my code I’m trying to implement the semi-supervised learning method FixMatch. However, I’m not sure whether I implemented it correctly. Can I first run the labeled data and then the unlabeled data through the model, or will this ruin my gradient computation? In another implementation I saw that they combine both into a single input.

In my code, input_la is a batch of labeled samples, input_ul_weak is a batch of weakly augmented unlabeled samples, and input_ul_strong is a batch of strongly augmented unlabeled samples.

# Get model predictions on labeled data
out_la = model(input_la)

# We don't need the gradient here, since we only use these outputs for pseudo-labels
with torch.no_grad():
    out_ul_weak = model(input_ul_weak)

out_ul_strong = model(input_ul_strong)

# Get pseudo-labels from the weakly augmented samples
conf, pseudo_label = F.softmax(out_ul_weak, dim=1).max(dim=1)

# Only use confident pseudo-labels
mask = conf > threshold

# Losses
loss_la = F.cross_entropy(out_la, label_la)
loss_ul = F.cross_entropy(out_ul_strong[mask], pseudo_label[mask])

loss_combined = loss_la + lambda_u * loss_ul

model.zero_grad()
loss_combined.backward()
optimizer.step()
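
For reference, the other implementation I mentioned concatenates everything into one batch and runs a single forward pass, roughly like this (my sketch of that pattern, not the code I am running):

# Sketch of the single-forward-pass variant: concatenate, run the model once, split again
inputs = torch.cat([input_la, input_ul_weak, input_ul_strong], dim=0)
out = model(inputs)

n_la = input_la.size(0)
n_ul = input_ul_weak.size(0)
out_la = out[:n_la]
out_ul_weak = out[n_la:n_la + n_ul].detach()   # pseudo-labels don't need a gradient
out_ul_strong = out[n_la + n_ul:]
# Note: batch-norm statistics are now computed over all three batches at once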

You can execute multiple forward passes, calculate different losses, combine them, and compute the gradients from the accumulated loss, so your code looks alright.
Are you seeing any unexpected behavior using this approach?
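
For reference, here is a minimal, self-contained example of the multiple-forward-pass pattern (toy model and random data, purely for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 3)                                   # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x1, y1 = torch.randn(4, 8), torch.randint(0, 3, (4,))    # first batch
x2, y2 = torch.randn(4, 8), torch.randint(0, 3, (4,))    # second batch

# Two separate forward passes
loss1 = F.cross_entropy(model(x1), y1)
loss2 = F.cross_entropy(model(x2), y2)

# Combine the losses and backpropagate once; the gradients of both terms
# accumulate in the same .grad buffers before the optimizer step
loss = loss1 + 0.5 * loss2
optimizer.zero_grad()
loss.backward()
optimizer.step()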

I get some rather weird behavior. If I train the model without the unlabeled data

out_la = model(input_la)
loss_la = F.cross_entropy(out_la, label_la)

model.zero_grad()
loss_la.backward()
optimizer.step()

it converges pretty fast. If I instead use my original code and set the threshold to 1, it obviously uses no unlabeled data, since the mask is False everywhere. What surprises me is that even in this case it does not really converge and behaves completely differently from the purely supervised code.

In my other question I tried to find out whether it’s a masking issue, but I’m not sure.
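
One thing I still want to check is whether the empty selection itself breaks the loss: with the threshold at 1, out_ul_strong[mask] and pseudo_label[mask] are empty tensors, and I suspect the mean-reduced cross-entropy then becomes NaN and poisons loss_combined. A guarded version I might try looks like this (just a sketch of the idea):

# Sketch: skip the unlabeled loss when no pseudo-label passes the threshold
if mask.any():
    loss_ul = F.cross_entropy(out_ul_strong[mask], pseudo_label[mask])
else:
    loss_ul = torch.zeros((), device=out_la.device)

loss_combined = loss_la + lambda_u * loss_ul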