In my code I’m trying to implement the semi-supervised learning method FixMatch. However, I’m not sure if I implemented it correctly. Can I first run the labeled data and then the unlabeled data through the model, or will this ruin my gradient computation? In another implementation I saw that they combine both to yield a single input.
In my code, input_la is a batch of labeled samples, input_ul_weak a batch of weakly augmented unlabeled samples, and input_ul_strong a batch of strongly augmented unlabeled samples.
import torch
import torch.nn.functional as F

# Get model predictions on labeled data
out_la = model(input_la)
# We don't need the gradient for the weakly augmented pass, since it is
# only used to produce pseudo labels
with torch.no_grad():
    out_ul_weak = model(input_ul_weak)
# The strongly augmented pass must stay OUTSIDE no_grad, otherwise
# loss_ul below carries no gradient
out_ul_strong = model(input_ul_strong)
# Get pseudo label from weakly augmented sample
conf, pseudo_label = F.softmax(out_ul_weak, dim=1).max(dim=1)
# Only use confident pseudo labels
mask = conf > threshold
# Losses
loss_la = F.cross_entropy(out_la, label_la)
# Guard against an empty mask, which would make cross_entropy return NaN
loss_ul = (F.cross_entropy(out_ul_strong[mask], pseudo_label[mask])
           if mask.any() else out_ul_strong.sum() * 0.0)
loss_combined = loss_la + lambda_u * loss_ul
model.zero_grad()
loss_combined.backward()
optimizer.step()
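For comparison, here is a minimal, self-contained sketch of the combined-input variant mentioned above, where the labeled and strongly augmented batches go through the model in a single forward pass and the logits are split afterwards. The toy linear model, random data, and hyperparameter values are stand-ins for illustration, not the actual training setup:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch_la, batch_ul = 4, 8, 16
model = nn.Linear(10, num_classes)  # toy stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
threshold, lambda_u = 0.95, 1.0

# Random stand-ins for the labeled / weakly / strongly augmented batches
input_la = torch.randn(batch_la, 10)
label_la = torch.randint(0, num_classes, (batch_la,))
input_ul_weak = torch.randn(batch_ul, 10)
input_ul_strong = torch.randn(batch_ul, 10)

# Weakly augmented pass only produces pseudo labels, so no_grad is fine here
with torch.no_grad():
    out_ul_weak = model(input_ul_weak)

# Single forward pass over the concatenated batch, then split the logits
out_all = model(torch.cat([input_la, input_ul_strong], dim=0))
out_la, out_ul_strong = out_all[:batch_la], out_all[batch_la:]

# Pseudo labels from the weak pass, kept only where confidence is high
conf, pseudo_label = F.softmax(out_ul_weak, dim=1).max(dim=1)
mask = conf > threshold

loss_la = F.cross_entropy(out_la, label_la)
# Empty-mask guard: cross_entropy on zero samples would return NaN
loss_ul = (F.cross_entropy(out_ul_strong[mask], pseudo_label[mask])
           if mask.any() else out_ul_strong.sum() * 0.0)
loss_combined = loss_la + lambda_u * loss_ul

optimizer.zero_grad()
loss_combined.backward()
optimizer.step()
print(model.weight.grad is not None)  # → True
```

As far as I understand, two separate forward passes do not break autograd itself (gradients from both losses accumulate correctly in `backward()`), but with BatchNorm layers the batch statistics are computed per sub-batch; concatenating the inputs keeps the statistics shared across labeled and unlabeled samples, which may be why the other implementation does it this way.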