Hi all, I’m trying to apply multi-task learning on semi-supervised data. I have a very small labeled training set compared to the unlabeled dataset (256 labeled vs 2421184 unlabeled samples). I need to find a way to iterate through the dataloaders simultaneously.
My question is similar to this one, but the trick of making the dataloaders the same length by adjusting their batch sizes is not applicable to my problem. Even if I set the batch size to 1 for the labeled loader, the batch size of the unlabeled loader would have to be 2421184 // 256 = 9457, which is too large to fit into memory.
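To make the arithmetic explicit, here is a quick check of the numbers above. The variable names are only for illustration.

```python
import math

n_labeled, n_unlabeled = 256, 2421184

# With a labeled batch size of 1, the labeled loader yields 256 batches per epoch.
labeled_batch_size = 1
n_labeled_batches = math.ceil(n_labeled / labeled_batch_size)

# Unlabeled batch size needed for the unlabeled loader to yield the same number of batches.
matched_unlabeled_batch_size = n_unlabeled // n_labeled_batches
print(n_labeled_batches, matched_unlabeled_batch_size)  # 256 9457
```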
In this case, would it be proper to accumulate the unlabeled losses in a loop to emulate equal dataloader lengths (example below)? If not, what would be the proper way to do it?
Note: The main task is regression and the auxiliary task is classification, so only labeled samples have regression targets, while all samples (labeled and unlabeled) have classification labels.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

labeled_dataset = CustomDataset(learning='labeled')
unlabeled_dataset = CustomDataset(learning='unlabeled')
dataset_coeff = len(unlabeled_dataset) // len(labeled_dataset)
loader_params = {'batch_size': 32, 'shuffle': True, 'num_workers': 4}
labeled_train_loader = DataLoader(labeled_dataset, **loader_params)
unlabeled_train_loader = DataLoader(unlabeled_dataset, **loader_params)
num_epochs = 100
for e in range(num_epochs):
    labeled_iter, unlabeled_iter = iter(labeled_train_loader), iter(unlabeled_train_loader)
    for batch_id in range(len(labeled_train_loader)):
        labeled_patch, labeled_reg_target, labeled_class_target = next(labeled_iter)
        # The model makes two predictions: regression and classification
        labeled_reg_pred, labeled_class_pred = model(labeled_patch)
        labeled_reg_loss = F.mse_loss(input=labeled_reg_pred, target=labeled_reg_target)
        labeled_class_loss = F.cross_entropy(input=labeled_class_pred, target=labeled_class_target)
        total_unlabeled_batch_loss = 0
        for uid in range(dataset_coeff):  # Emulate same dataloader size
            unlabeled_patch, _, unlabeled_class_target = next(unlabeled_iter)
            # Unlabeled samples don't have regression values, so only the classification loss is calculated
            _, unlabeled_class_pred = model(unlabeled_patch)
            unlabeled_class_loss = F.cross_entropy(input=unlabeled_class_pred, target=unlabeled_class_target)
            total_unlabeled_batch_loss = total_unlabeled_batch_loss + unlabeled_class_loss
        # Divide by dataset_coeff so the accumulated unlabeled loss does not dominate the final loss
        final_loss = labeled_reg_loss + labeled_class_loss + total_unlabeled_batch_loss / dataset_coeff
        final_loss.backward()
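For anyone who wants to reproduce the setup, here is a minimal self-contained sketch of the same loop on toy data. Everything in it is an assumption made for illustration: `TensorDataset` stands in for `CustomDataset`, `TwoHeadNet` is a hypothetical two-output model, and the sizes are tiny. It also uses one possible variant of the accumulation: each unlabeled batch's scaled loss gets its own `backward()` call, so gradients are summed in `.grad` rather than the losses being summed into one graph. I am not sure whether this is the recommended approach.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for CustomDataset: 8 labeled and 64 unlabeled samples,
# 10 features, 3 classes. Unlabeled samples carry no regression target here.
labeled_dataset = TensorDataset(torch.randn(8, 10), torch.randn(8, 1), torch.randint(0, 3, (8,)))
unlabeled_dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))
dataset_coeff = len(unlabeled_dataset) // len(labeled_dataset)  # 64 // 8 = 8


class TwoHeadNet(torch.nn.Module):
    """Hypothetical model: shared trunk, one regression head, one classification head."""

    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Linear(10, 16)
        self.reg_head = torch.nn.Linear(16, 1)
        self.class_head = torch.nn.Linear(16, 3)

    def forward(self, x):
        h = torch.relu(self.trunk(x))
        return self.reg_head(h), self.class_head(h)


model = TwoHeadNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
labeled_loader = DataLoader(labeled_dataset, batch_size=4, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=4, shuffle=True)

steps = 0
unlabeled_iter = iter(unlabeled_loader)
for patch, reg_target, class_target in labeled_loader:
    optimizer.zero_grad()
    reg_pred, class_pred = model(patch)
    labeled_loss = F.mse_loss(reg_pred, reg_target) + F.cross_entropy(class_pred, class_target)
    labeled_loss.backward()
    for _ in range(dataset_coeff):
        u_patch, u_class_target = next(unlabeled_iter)
        _, u_class_pred = model(u_patch)
        # Scale each unlabeled loss by 1 / dataset_coeff and backpropagate right away;
        # the gradients add up in .grad, and each batch's graph can be freed afterwards.
        (F.cross_entropy(u_class_pred, u_class_target) / dataset_coeff).backward()
    optimizer.step()
    steps += 1
```

Since the gradient of a sum is the sum of the gradients, this should produce the same parameter gradients as a single `backward()` on the combined loss, while only one unlabeled batch's graph is alive at a time.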