I am working on building a DANN (domain-adversarial neural network, Ganin et al. 2016) in PyTorch. This model is used for domain adaptation: it pushes the feature extractor to learn only features that are shared by two different domains, so that the classifier generalizes across them. The DANN uses a gradient reversal layer (GRL) to achieve this.
I have seen some suggestions on this forum on how to modify gradients manually. However, I found them difficult to apply in my case, because the gradients are reversed midway through the backward pass (see the image: the gradients are reversed by the GRL, the gradient reversal layer, once the backward pass reaches the feature extractor).
Below is my take on this, though I am not sure if I have used the hook correctly. I would greatly appreciate some suggestions on how to use these hooks effectively!
# this is the hook to reverse the gradients
def grad_reverse(grad):
    return grad.clone() * -lambd

lambd = 1

# 2) train feature_extractor and domain_classifier on full batch x
# reset gradients
f_ext.zero_grad()
d_clf.zero_grad()

# calculate domain_classifier predictions on batch x
d_out = d_clf(f_ext(x).view(batch_size, -1))

# use normal gradients to optimize domain_classifier
f_d_loss = d_crit(d_out, yd.float())
f_d_loss.backward(retain_variables = True)
d_optimizer.step()

# use reversed gradients to optimize feature_extractor
d_out.register_hook(grad_reverse)
f_d_loss = d_crit(d_out, yd.float())
f_d_loss.backward(retain_variables = True)
f_optimizer.step()
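For comparison, here is a minimal sketch of one common alternative to a tensor hook: a custom `torch.autograd.Function` that acts as the identity in the forward pass and multiplies the incoming gradient by `-lambd` in the backward pass. The class name `GradReverse`, the helper signature, and the toy `Linear` modules and random data are assumptions made for illustration; they only stand in for the real `f_ext`, `d_clf`, `d_crit`, `x` and `yd`, which are not shown in the post.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; scales the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd          # keep the scaling factor for the backward pass
        return x.view_as(x)        # same values, new node in the autograd graph

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; lambd is a plain number, so it gets None.
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    # Hypothetical convenience wrapper around the Function above.
    return GradReverse.apply(x, lambd)


# Toy stand-ins (assumptions) for the feature extractor, domain classifier and data.
torch.manual_seed(0)
f_ext = torch.nn.Linear(4, 3)
d_clf = torch.nn.Linear(3, 1)
d_crit = torch.nn.BCEWithLogitsLoss()
x = torch.randn(8, 4)
yd = torch.randint(0, 2, (8, 1)).float()

# The reversal sits between the feature extractor and the domain classifier.
features = f_ext(x)
d_out = d_clf(grad_reverse(features, lambd=1.0))
f_d_loss = d_crit(d_out, yd)
f_d_loss.backward()  # d_clf gets ordinary gradients, f_ext gets reversed ones
```

With the reversal placed between `f_ext` and `d_clf`, a single `backward()` call is enough for both modules, so no second backward pass or retained graph is needed in this sketch. A hook registered on `d_out`, by contrast, flips the gradient before it reaches `d_clf` as well, since `d_out` is the output of the domain classifier.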