I am trying to reproduce the paper "Cross-Modality Person Re-Identification with Generative Adversarial Training" in PyTorch.
The losses, including the discriminator loss loss_D, are defined as:
```python
import torch

criterion_triplet = torch.nn.TripletMarginLoss(margin=1.4)
criterion_identity = torch.nn.CrossEntropyLoss()
criterion_modality = torch.nn.BCEWithLogitsLoss()

# ...

# Generator: cross-modality triplet loss (RGB anchors vs. IR positives/negatives and vice versa)
triplet_loss_rgb = criterion_triplet(anchor_rgb_features,
                                     positive_ir_features, negative_ir_features)
triplet_loss_ir = criterion_triplet(anchor_ir_features,
                                    positive_rgb_features, negative_rgb_features)
triplet_loss = triplet_loss_rgb + triplet_loss_ir

# Generator: identity loss from a classifier shared by both modalities
predicted_id_rgb = id_classifier(anchor_rgb_features)
predicted_id_ir = id_classifier(anchor_ir_features)
identity_loss = criterion_identity(predicted_id_rgb, anchor_label) + \
    criterion_identity(predicted_id_ir, anchor_label)
loss_G = alpha * triplet_loss + beta * identity_loss

# Discriminator: predicts which modality each feature came from
predicted_rgb_modality = mode_classifier(anchor_rgb_features)
predicted_ir_modality = mode_classifier(anchor_ir_features)
loss_D = criterion_modality(predicted_rgb_modality, modality_rgb) + \
    criterion_modality(predicted_ir_modality, modality_ir)
```
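As a sanity check on the triplet terms, here is a minimal sketch with random tensors showing what torch.nn.TripletMarginLoss computes: the batch mean of max(d(a, p) - d(a, n) + margin, 0), where d is the Euclidean distance. The batch size, feature dimension, seed, and tensor names are arbitrary assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
batch_size, feat_dim, margin = 8, 128, 1.4
torch.manual_seed(0)

# Stand-ins for RGB anchor features and IR positive/negative features.
anchor_rgb = torch.randn(batch_size, feat_dim)
positive_ir = torch.randn(batch_size, feat_dim)
negative_ir = torch.randn(batch_size, feat_dim)

criterion_triplet = torch.nn.TripletMarginLoss(margin=margin)

# Built-in loss (default p=2, reduction='mean').
builtin = criterion_triplet(anchor_rgb, positive_ir, negative_ir)

# The same quantity written out by hand.
d_ap = F.pairwise_distance(anchor_rgb, positive_ir)  # distance anchor -> positive, shape (batch_size,)
d_an = F.pairwise_distance(anchor_rgb, negative_ir)  # distance anchor -> negative, shape (batch_size,)
manual = torch.clamp(d_ap - d_an + margin, min=0).mean()

print(builtin.item(), manual.item())
```

Both numbers should agree up to floating-point error, which makes it easy to check that anchors, positives, and negatives are paired across modalities as intended.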
Training starts out fine, but after some time the discriminator loss suddenly goes haywire, and after that …
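For reference, here is one common way to structure the discriminator update in adversarial training. It is offered as an assumption about the elided training loop, not as what the paper prescribes: the features are passed through detach() so that loss_D updates only mode_classifier and not the feature extractor, and the BCEWithLogitsLoss targets are float tensors with the same shape as the logits. The generator, the mode_classifier architecture, the sizes, the learning rate, and the RGB = 1 / IR = 0 labelling are all hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, in_dim, feat_dim = 8, 64, 128

# Hypothetical stand-ins: a linear "feature extractor" for the generator and a
# small modality classifier that outputs one logit per sample.
generator = nn.Linear(in_dim, feat_dim)
mode_classifier = nn.Sequential(nn.Linear(feat_dim, 64), nn.ReLU(), nn.Linear(64, 1))

opt_D = torch.optim.Adam(mode_classifier.parameters(), lr=1e-4)
criterion_modality = nn.BCEWithLogitsLoss()

anchor_rgb_features = generator(torch.randn(batch_size, in_dim))
anchor_ir_features = generator(torch.randn(batch_size, in_dim))

# Float targets shaped like the logits (batch_size, 1); assumed labels: RGB = 1, IR = 0.
modality_rgb = torch.ones(batch_size, 1)
modality_ir = torch.zeros(batch_size, 1)

# Discriminator step: detach() cuts the graph so loss_D cannot reach the generator.
predicted_rgb_modality = mode_classifier(anchor_rgb_features.detach())
predicted_ir_modality = mode_classifier(anchor_ir_features.detach())
loss_D = criterion_modality(predicted_rgb_modality, modality_rgb) + \
    criterion_modality(predicted_ir_modality, modality_ir)

opt_D.zero_grad()
loss_D.backward()
opt_D.step()

# The generator received no gradient from loss_D.
print(generator.weight.grad)  # None
```

With this split, the generator is updated only through its own loss in a separate step, so the two optimizers do not fight over the same gradients within one backward pass.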
Any input on what may be happening here would be extremely helpful.