In my code I’m trying to implement the semi-supervised learning method FixMatch. However, I’m not sure whether I implemented it correctly. Can I first run the labeled data and then the unlabeled data through the model, or will this ruin my gradient computation? In another implementation I saw that both are concatenated into a single input for one forward pass.
In my code,
input_la is a batch of labeled samples,
input_ul_weak is a batch of weakly augmented unlabeled samples, and
input_ul_strong is a batch of strongly augmented unlabeled samples.
```python
import torch
import torch.nn.functional as F

# Get model predictions on labeled data
out_la = model(input_la)

# We don't need the gradient here since we are only using the pseudo labels
with torch.no_grad():
    out_ul_weak = model(input_ul_weak)

# The strong branch needs gradients, so it stays outside no_grad
out_ul_strong = model(input_ul_strong)

# Get pseudo label from weakly augmented sample
conf, pseudo_label = F.softmax(out_ul_weak, dim=1).max(dim=1)

# Only use confident pseudo labels
mask = conf > threshold

# Losses
loss_la = F.cross_entropy(out_la, label_la)
loss_ul = F.cross_entropy(out_ul_strong[mask], pseudo_label[mask])
loss_combined = loss_la + lambda_u * loss_ul

model.zero_grad()
loss_combined.backward()
optimizer.step()
```
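For comparison, here is a minimal sketch of the single-forward-pass variant the question mentions: the three batches are concatenated, run through the model once, and the outputs are split back apart. Gradient-wise this is equivalent to separate forward passes; the practical difference is batch-norm statistics (all three batches share them) and GPU utilization. The toy `torch.nn.Linear` model and the batch sizes below are illustrative stand-ins, not part of the original code.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the real model and data loaders
torch.manual_seed(0)
model = torch.nn.Linear(8, 3)        # toy classifier: 8 features, 3 classes
input_la = torch.randn(4, 8)         # labeled batch
input_ul_weak = torch.randn(16, 8)   # weakly augmented unlabeled batch
input_ul_strong = torch.randn(16, 8) # strongly augmented unlabeled batch

# One forward pass over the concatenated batch, then split the outputs
# back into the three original groups.
inputs = torch.cat([input_la, input_ul_weak, input_ul_strong], dim=0)
outputs = model(inputs)
out_la, out_ul_weak, out_ul_strong = outputs.split(
    [input_la.size(0), input_ul_weak.size(0), input_ul_strong.size(0)],
    dim=0,
)

# Pseudo labels still come from the weak branch; no_grad ensures no
# gradient flows through the pseudo-labeling step.
with torch.no_grad():
    conf, pseudo_label = F.softmax(out_ul_weak, dim=1).max(dim=1)
```

Note that `out_la` and `out_ul_strong` here still carry gradients, since only the pseudo-labeling step is wrapped in `torch.no_grad()`.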