Gradient Reversal on data of a specific class

Hi, I’m new to PyTorch and I’m trying to implement domain-adversarial learning in my network with the help of a gradient reversal layer. Suppose my main task is binary classification (positive or negative label); in addition to this I have a domain classifier. Both my binary and domain classifiers operate on features produced by the feature extractor. However, my domain classifier only operates on data belonging to the positive class.
My question is whether the domain loss and the binary loss in my code below actually affect the feature extractor. (The aim is to produce features that are more accurate and at the same time domain invariant.)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Define the feature extractor
class FeatureExtractor(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(FeatureExtractor, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return x

# Define the binary classifier
class BinaryClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BinaryClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

# Define the domain classifier
class DomainClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DomainClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

# Define the gradient reversal layer: identity in the forward pass,
# sign-flipped gradients in the backward pass
class GradientReverseLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)  # identity; view_as keeps the op in the autograd graph

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()  # reverse the gradient flowing toward earlier layers

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Hyperparameters
input_size = 10
hidden_size = 20
output_size = 1
domain_output_size = 2
learning_rate = 0.001
num_epochs = 10

# Dummy data
data = torch.randn(100, input_size)
labels = torch.randint(0, 2, (100,))
domain_labels = torch.tensor([[1, 0] if label == 1 else [0, 1] for label in labels], dtype=torch.float32)

# Initialize models, criterion, and optimizers
feature_extractor = FeatureExtractor(input_size, hidden_size)
binary_classifier = BinaryClassifier(hidden_size, hidden_size, output_size)
domain_classifier = DomainClassifier(hidden_size, hidden_size, domain_output_size)
criterion = nn.BCELoss()
# The feature extractor's parameters must be registered with the optimizers;
# otherwise they receive gradients but are never updated by step()
binary_optimizer = optim.Adam(
    list(feature_extractor.parameters()) + list(binary_classifier.parameters()),
    lr=learning_rate)
domain_optimizer = optim.Adam(
    list(feature_extractor.parameters()) + list(domain_classifier.parameters()),
    lr=learning_rate)

# Train loop
for epoch in range(num_epochs):
    for input_data, label, domain_label in zip(data, labels, domain_labels):
        # Forward pass through feature extractor; unsqueeze adds a batch
        # dimension so the linear layers and Softmax(dim=1) see shape (1, ...)
        features = feature_extractor(input_data.unsqueeze(0))

        # Forward pass through binary classifier
        binary_output = binary_classifier(features)

        # Forward pass through gradient reversal layer and domain classifier for positive label samples
        if label == 1:
            reversed_features = GradientReverseLayer.apply(features)
            domain_output = domain_classifier(reversed_features)

            # Compute domain classification loss
            domain_loss = criterion(domain_output, domain_label.unsqueeze(0))

            # Backpropagation for domain classifier; retain_graph=True because
            # the binary loss below backpropagates through the same features
            domain_optimizer.zero_grad()
            domain_loss.backward(retain_graph=True)
            domain_optimizer.step()

        # Compute binary classification loss (target reshaped to match the
        # (1, 1) output of the binary classifier)
        binary_loss = criterion(binary_output, label.view(1, 1).float())

        # Backpropagation for binary classifier
        binary_optimizer.zero_grad()
        binary_loss.backward()
        binary_optimizer.step()

I’m unsure if I misunderstand the question, but since domain_loss.backward() is only called when label == 1, it won’t have any effect on the feature_extractor in the other cases. In general, autograd records the operations that produced features, so calling backward() on a loss computed from them propagates gradients all the way back into feature_extractor, even though the domain classifier never calls it directly (its weights will then only change if its parameters are registered with the optimizer taking the step). Let me know if I misunderstood the question.
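You can verify this directly by inspecting the gradients on the feature extractor after the backward call. A minimal sanity check, reusing the modules and dummy data from your code above (picking the first sample is arbitrary):

# Sanity check: gradients from the domain loss reach the feature extractor
sample = data[0].unsqueeze(0)                       # shape (1, input_size)
features = feature_extractor(sample)
reversed_features = GradientReverseLayer.apply(features)
domain_output = domain_classifier(reversed_features)
domain_loss = criterion(domain_output, domain_labels[0].unsqueeze(0))

feature_extractor.zero_grad()
domain_loss.backward()
print(feature_extractor.fc1.weight.grad)            # non-zero, sign-flipped by the GRL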

First of all, thank you so much for your reply! Yes, that is what I intend to do: I want the feature extractor to produce domain-invariant features only for the positive class (label 1). I was just having trouble understanding how backpropagation works in this case. So, in conclusion: even though I never call the feature extractor directly inside the domain classifier, but instead pass the feature extractor’s output to the domain classifier in the training loop, optimizing the domain loss affects the feature extractor as well?
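Just to check my understanding, the graph seems to travel with the features tensor itself, which I can see by inspecting it:

feats = feature_extractor(data[0].unsqueeze(0))
print(feats.grad_fn)        # e.g. <ReluBackward0 ...>, linking back to the extractor
print(feats.requires_grad)  # True, so a loss built on feats can reach its weights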
Additionally, I have another question: I only want to use the binary classifier during testing. How exactly do I save my model after training? Will I need to save the feature extractor and the binary classifier separately and call them in the same way as above during testing? Am I doing this correctly?
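Concretely, this is roughly what I have in mind (the file names and test_data are just placeholders I made up):

# Save only what is needed at test time (state_dict-based; placeholder file names)
torch.save(feature_extractor.state_dict(), "feature_extractor.pt")
torch.save(binary_classifier.state_dict(), "binary_classifier.pt")

# At test time: rebuild the modules, load the weights, switch to eval mode
feature_extractor = FeatureExtractor(input_size, hidden_size)
feature_extractor.load_state_dict(torch.load("feature_extractor.pt"))
binary_classifier = BinaryClassifier(hidden_size, hidden_size, output_size)
binary_classifier.load_state_dict(torch.load("binary_classifier.pt"))
feature_extractor.eval()
binary_classifier.eval()

with torch.no_grad():
    test_data = torch.randn(5, input_size)  # placeholder test batch
    predictions = binary_classifier(feature_extractor(test_data))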