Effect of computational graph construction in adversarial domain adaptation

(Andreas Triantafyllopoulos) #1

My question is about implementing DANN (https://arxiv.org/pdf/1409.7495.pdf) in PyTorch. To reproduce the paper's results, I needed three neural networks: a feature extractor whose output feeds both a task-specific model and a domain predictor, the latter through a Gradient Reversal Layer.
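For reference, a gradient reversal layer is commonly written as a custom torch.autograd.Function that is the identity on the forward pass and multiplies the incoming gradient by -lambda on the backward pass. The sketch below follows that common pattern. The names GradReverse and grad_reverse, and the toy nn.Linear stand-ins with their sizes, are hypothetical and not my actual models.

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity on the forward pass; scales the gradient by -lambd on the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: reversed gradient for x, None for lambd.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


# Hypothetical toy stand-ins for the three networks.
torch.manual_seed(0)
feature_model = nn.Linear(4, 3)
task_model = nn.Linear(3, 2)    # sees the features directly
domain_model = nn.Linear(3, 1)  # sees the features through the reversal layer

x = torch.randn(5, 4)
features = feature_model(x)
task_logits = task_model(features)
domain_logits = domain_model(grad_reverse(features, lambd=0.5))

# Sanity check: d(sum(z))/dz is all ones, so after reversal every entry is -lambd.
z = torch.ones(3, requires_grad=True)
grad_reverse(z, lambd=0.5).sum().backward()
print(z.grad.tolist())  # [-0.5, -0.5, -0.5]
```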

I wanted a baseline for my implementation without domain adaptation, so I set the lambda parameter mentioned in the paper to 0, which effectively disables domain adaptation. However, the results were very different from those of a single neural network trained for classification, so I decided to investigate.
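With lambda = 0 the reversal layer passes back a zero gradient, so the domain branch should have no influence on the feature extractor's gradients. A small sketch to check this compares those gradients with and without the domain loss. It assumes toy linear models and random labels (all hypothetical), and it implements the reversal with a tensor hook (Tensor.register_hook) rather than a custom autograd.Function; that is an alternative formulation, not the paper's code.

```python
import torch
from torch import nn
import torch.nn.functional as F


def reverse_with_hook(x, lambd):
    """Gradient reversal via a tensor hook: identity forward, gradient scaled by -lambd."""
    y = x.view_as(x)
    y.register_hook(lambda g: -lambd * g)
    return y


def feature_grads(lambd, with_domain_loss):
    """Return the feature extractor's weight gradient after one backward pass."""
    torch.manual_seed(0)  # identical models and data on every call
    feature_model = nn.Linear(4, 3)
    task_model = nn.Linear(3, 2)
    domain_model = nn.Linear(3, 2)
    x = torch.randn(8, 4)
    y_task = torch.randint(0, 2, (8,))
    y_domain = torch.randint(0, 2, (8,))

    features = feature_model(x)
    loss = F.cross_entropy(task_model(features), y_task)
    if with_domain_loss:
        domain_logits = domain_model(reverse_with_hook(features, lambd))
        loss = loss + F.cross_entropy(domain_logits, y_domain)
    loss.backward()
    return feature_model.weight.grad.clone()


baseline = feature_grads(lambd=0.0, with_domain_loss=False)
dann_lambda0 = feature_grads(lambd=0.0, with_domain_loss=True)
print(torch.allclose(baseline, dann_lambda0))  # True
```

If this holds, lambda = 0 and the single-network baseline should agree at the gradient level, so a difference between them has to come from somewhere other than the loss.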

I managed to narrow down my problem to a single line of code.

A minimal example is as follows:

```python
import torch
import torch.nn.functional as F

# initialize feature extractor and task-specific model
self.feature_model = [...]
self.task_model = [...]

# jointly optimize feature extractor and task model
self.optimizer = torch.optim.SGD([{'params': self.feature_model.parameters()},
                                  {'params': self.task_model.parameters()}],
                                 lr=0.01, momentum=0.9)

for i, (data_s, data_t) in enumerate(zip(self.source_loader, self.target_loader)):  # iterate over source and target data
    # get source and target features
    s_features = data_s['features'].float().cuda()
    t_features = data_t['features'].float().cuda()

    # get feature embeddings using feature model
    s_embedded_features, _ = self.feature_model(s_features)

    # The following line is where my problem is located
    t_embedded_features, _ = self.feature_model(t_features)

    self.optimizer.zero_grad()  # clear gradients left over from the previous iteration
    outputs, _ = self.task_model(s_embedded_features)  # task predictions (assumed log-probabilities, as nll_loss expects)
    labels = data_s['class'].long().cuda()  # task labels (wrapping in Variable is unnecessary since PyTorch 0.4)
    task_loss = F.nll_loss(outputs, labels)  # compute task loss

    task_loss.backward()  # back-propagate task loss
    self.optimizer.step()  # update both models
```

When I comment out the line that passes the target features through my feature extractor (this one: t_embedded_features, _ = self.feature_model(t_features)), the results are what I would expect, i.e. almost identical to those obtained with a single neural network.

After uncommenting this line, however, the classification results change markedly: accuracy increases by 10%. This is the same improvement I saw after enabling domain adaptation.

I think my problem is related to how the computational graph is constructed. Specifically, I suspect that passing the target features through the feature extractor somehow affects how the gradient is computed, but I fear I am missing some fundamental understanding of how the graph is built.
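As a point of comparison, an output that never reaches the loss should not change the gradients at all, because backward() only traverses the part of the graph reachable from task_loss. The sketch below checks this for a hypothetical feature extractor without BatchNorm or Dropout; the layer sizes and random data are made up for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F


def grads_after_backward(extra_forward):
    """Return the feature extractor's gradients, optionally after an unused target forward pass."""
    torch.manual_seed(0)  # identical models and data on every call
    # Hypothetical feature extractor with no BatchNorm and no Dropout.
    feature_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
    task_model = nn.Linear(3, 2)
    s_features = torch.randn(16, 4)
    t_features = torch.randn(16, 4)
    labels = torch.randint(0, 2, (16,))

    s_embedded = feature_model(s_features)
    if extra_forward:
        # Builds a second branch of the graph, but nothing downstream uses it.
        t_embedded = feature_model(t_features)
    loss = F.cross_entropy(task_model(s_embedded), labels)
    loss.backward()
    return [p.grad.clone() for p in feature_model.parameters()]


without_target = grads_after_backward(extra_forward=False)
with_target = grads_after_backward(extra_forward=True)
print(all(torch.allclose(a, b) for a, b in zip(without_target, with_target)))  # True
```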

I could really use some help getting this to work. Ideally, I would like to fix the problem in the code above, so that I can be certain that any improvement actually comes from domain adaptation and not from somehow training on the target set itself.

P.S. I tried setting t_features = Variable(data_t['features'].float().cuda(), requires_grad=False), but that didn't work.
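Setting requires_grad=False on the input does not take the forward pass out of the graph, because the model's own parameters still require gradients; wrapping the call in torch.no_grad() does. (Since PyTorch 0.4, Variable(...) simply returns a tensor.) A minimal sketch, with a hypothetical nn.Linear standing in for the feature extractor:

```python
import torch
from torch import nn

feature_model = nn.Linear(4, 3)  # hypothetical stand-in for the feature extractor
t_features = torch.randn(2, 4, requires_grad=False)  # same as the default for data tensors

out = feature_model(t_features)
print(out.requires_grad)  # True: the weights require grad, so the output is still in the graph

with torch.no_grad():
    out_no_grad = feature_model(t_features)
print(out_no_grad.requires_grad)  # False: no graph is recorded for this call
```

In the loop above neither variant changes what backward() computes, since t_embedded_features never reaches the loss.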

UPDATE: I have further isolated the problem: it only occurs when my feature extractor contains BatchNormalization. That network consists of three layers of the same form, Dropout(ReLU(BatchNormalization(Linear))). Removing BatchNormalization from all layers fixes the issue, but my question still stands: I would really like to know why it affects my results in this way.
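One BatchNorm property worth checking here: in training mode, every forward pass updates the layer's running_mean and running_var buffers, independently of autograd and of the optimizer, and those buffers are what the layer uses in eval mode. The sketch below builds a feature extractor of the form described (the layer sizes, dropout rate and the shifted "target" batch are hypothetical) and checks whether a gradient-free forward pass moves the first layer's running mean. It only shows that the buffers change; whether this accounts for the 10% difference is an assumption to verify on the real model.

```python
import torch
from torch import nn


def block(d_in, d_out):
    # One layer of the form Dropout(ReLU(BatchNorm(Linear))).
    return nn.Sequential(nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU(), nn.Dropout(0.5))


def running_mean_changes(model, bn, batch):
    """Return True if a gradient-free forward pass changes bn.running_mean."""
    before = bn.running_mean.clone()
    with torch.no_grad():  # no graph, no gradients, no optimizer step
        model(batch)
    return not torch.equal(before, bn.running_mean)


torch.manual_seed(0)
# Hypothetical layer sizes; the post does not give them.
feature_model = nn.Sequential(block(10, 32), block(32, 32), block(32, 16))
bn = feature_model[0][1]  # BatchNorm1d of the first layer

t_features = torch.randn(64, 10) + 3.0  # a hypothetical "target" batch with a shifted distribution

feature_model.train()
print(running_mean_changes(feature_model, bn, t_features))  # True: stats move toward the target batch

feature_model.eval()
print(running_mean_changes(feature_model, bn, t_features))  # False: eval mode uses, but does not update, the stats
```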

(John1231983) #2

Does the thread "[Solved] Reverse gradients in backward pass" solve your problem?

(Andreas Triantafyllopoulos) #3

That is the post I based my code on (though I later replaced the three optimizers with one). I have since isolated the problem further: it only occurs when my feature extractor contains BatchNormalization (see the update in my original post).