Computing gradient inside torch.autograd.Function

Hi,

I am trying to implement a custom loss function that involves computing an inner optimization loop. I am using a torch.autograd.Function because I would like to manually compute the gradient.
The function looks something like this:

# Inside the forward method of the Function
output_theta = model(X)
theta_star = {}
theta_initial = deepcopy(model.state_dict())
for name, v in theta_initial.items():
    random_val = torch.randn(size=v.shape, device=device)
    theta_star[name] = torch.tensor(v + v.abs() * initial_std * random_val, requires_grad=True)

for _ in range(n_attack_steps):
    for name, v in model.named_parameters():
        v.copy_(theta_star[name])
    output_theta_star = model(X)
    step_loss = robustness_loss(torch.nn.functional.softmax(output_theta, dim=1).log(), output_theta_star)
    step_loss.backward()

The implementation is not finished yet, but the problem I have now is that I can't get any gradients on theta_star.
Is it generally not possible to calculate gradients inside a Function?
I have the feeling that using a Function is not the way to go here, but I need to implement the gradient in a custom fashion.
As a reference, I want to implement the loss function of this paper: https://arxiv.org/pdf/2106.05009.pdf
What do I do?

Hey!

Would you mind sharing what exactly X is / where it is calculated? Or is it just arbitrary input data?

From what I can see, you're basically taking your network, modifying its parameters to create a copy, and then comparing the difference in output between the original model and the copy with the modified parameters? Is this correct?

When copying across you use copy_, which is an in-place copy. Perhaps it'd be better to deepcopy the model rather than the model's state_dict, and modify the parameters in one go rather than in two separate loops. Something like this might work!

model_star = deepcopy(model)
for model_np, model_star_np in zip(model.named_parameters(), model_star.named_parameters()):
    name, param = model_np                # unpack both tuples (name and param)
    name_star, param_star = model_star_np # for each model

    random_val = torch.randn(*param_star.shape, device=device, requires_grad=True)
    with torch.no_grad():  # copy_ into a leaf that requires grad must happen under no_grad
        param_star.copy_(param + param.abs() * initial_std * random_val)  # in-place copy across to model_star

#calculate both outputs here
output_theta = model(X)
output_theta_star = model_star(X)

#Calc. loss here
step_loss = robustness_loss(torch.nn.functional.softmax(output_theta, dim=1).log(), output_theta_star)

optim.zero_grad() #don't forget to zero your grads! (unless you just want the grad of loss w.r.t both model parameters?)
step_loss.backward()
optim.step()

When creating a custom loss function, writing a standard Python function should be fine. torch.autograd.Function is used when you want to create a custom function with explicit forward and backward methods, for example a custom activation function or module layer! A custom ReLU example is in the docs here → PyTorch: Defining New autograd Functions — PyTorch Tutorials 1.7.0 documentation
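For reference, the skeleton from that tutorial looks roughly like this (a minimal sketch of the custom-ReLU example; MyReLU is just a placeholder name):

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so backward() can build the gradient mask
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through only where the input was positive
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Functions are applied via .apply, e.g.
relu = MyReLU.apply
y = relu(torch.randn(4, requires_grad=True))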

Thanks for the reply. I think this is the way. I now decided to make a class out of it and keep two copies of the model, which are updated every time the loss is called. If you are interested, this is how I did it:
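In case it helps, the gradient the class accumulates is, in the notation of the comments in the code (my reading of the paper, treating the signed gradients of the attack as constants):

d L / d theta = d L_nat / d theta + beta_robustness * (d L_rob / d theta* * d theta* / d theta + d L_rob / d theta)

where d theta* / d theta is a diagonal Jacobian with entries

1 + sign(theta) * (initial_std * random_val + mismatch_level / n_attack_steps * sum_signed_grads)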

Paste the following into torch_loss.py

import torch
from copy import deepcopy

class AdversarialLoss():

    def __init__(
        self,
        model,
        natural_loss,
        robustness_loss,
        device,
        n_attack_steps,
        mismatch_level,
        initial_std,
        beta_robustness
    ):
        # - Make copies of the models
        self.model_theta = deepcopy(model)
        self.model_theta_star = deepcopy(model)
        self.natural_loss = natural_loss
        self.robustness_loss = robustness_loss
        self.device = device
        self.n_attack_steps = n_attack_steps
        self.mismatch_level = mismatch_level
        self.initial_std = initial_std
        self.beta_robustness = beta_robustness

    def L_rob(
        self,
        output_theta,
        output_theta_star
    ):
        return self.robustness_loss(
            torch.nn.functional.softmax(output_theta_star, dim=1).log(),
            torch.nn.functional.softmax(output_theta, dim=1)
        )


    def _adversarial_loss(
        self,
        model,
        X
    ):

        # - Update the parameters of the "healthy" model
        self.model_theta.load_state_dict(model.state_dict())

        # - Initialize theta* with small gaussian noise
        with torch.no_grad():
            # - f(X,theta)
            output_theta = self.model_theta(X)
            # - Accumulate the signed gradients for the gradient calculation
            sum_signed_grads = {}
            # - Compute theta*
            theta_star = {}
            # - Step size is scaled to each parameter and determines how much the adversary can affect the parameter
            step_size = {}
            # - Store random vals for the gradient computation
            random_val_dict = {}
            for name, v in self.model_theta.named_parameters():
                sum_signed_grads[name] = torch.zeros_like(v, device=self.device)
                # print("!! WARNING Using torch.ones_like as random initial pert.")
                random_val = torch.randn(size=v.shape, device=self.device)
                # random_val = torch.ones_like(v, device=self.device)
                random_val_dict[name] = random_val
                theta_star[name] = v + v.abs() * self.initial_std * random_val
                step_size[name] = (self.mismatch_level * v.abs()) / self.n_attack_steps

        # - PGA attack
        for _ in range(self.n_attack_steps):
            # - Load the initial theta_star
            self.model_theta_star.load_state_dict(theta_star)
            # - Pass input through net with adv. parameters and compute grad of robustness loss
            output_theta_star = self.model_theta_star(X)
            step_loss = self.L_rob(output_theta=output_theta, output_theta_star=output_theta_star)
            step_loss.backward()
            # - Update the sum of the signed gradients
            for name,v in self.model_theta_star.named_parameters():
                sum_signed_grads[name] += v.grad.sign()

            # - Update theta*
            for name,v in self.model_theta_star.named_parameters():
                theta_star[name] = theta_star[name] + step_size[name] * v.grad.sign()
                v.grad = None # - Ensure gradients don't accumulate

            # - After updating theta_star, load the new weights into the network
            self.model_theta_star.load_state_dict(theta_star)

        # - Calculate d L_rob / d theta* for computing the final gradient
        output_theta_star = self.model_theta_star(X)
        loss_rob = self.L_rob(output_theta=output_theta, output_theta_star=output_theta_star)
        loss_rob.backward()
        grad_L_theta_star = {}
        for name,v in self.model_theta_star.named_parameters():
            grad_L_theta_star[name] = v.grad # - Store the gradients
            v.grad = None
            
        # - The final gradient can be computed using:  d L / d theta* * d theta* / d theta + d L / d theta
        grad_L_theta = {}
        with torch.no_grad():
            output_theta_star = self.model_theta_star(X)
        loss_rob = self.L_rob(output_theta=self.model_theta(X), output_theta_star=output_theta_star)
        loss_rob.backward()
        for name,v in self.model_theta.named_parameters():
            grad_L_theta[name] = v.grad
            v.grad = None


        # - Compute d theta* / d theta which is the Jacobian. J is diagonal so we can just keep the shape.
        # - See https://arxiv.org/abs/2106.05009
        J_diag = { name: (1.0 + v.sign() * (self.initial_std * random_val_dict[name] + \
                    self.mismatch_level / self.n_attack_steps * sum_signed_grads[name])).detach() \
                    for name,v in self.model_theta.named_parameters()}

        # - Final gradient
        final_grad = {name: grad_L_theta_star[name] * J_diag[name] + grad_L_theta[name] for name in J_diag}
        return loss_rob.detach(), final_grad

    def compute_gradient_and_backward(
        self,
        model,
        X,
        y
    ):

        if self.beta_robustness != 0.0:
            # - Get the adversarial loss (note: beta_robustness is not applied yet)
            adv_loss, adv_loss_gradients = self._adversarial_loss(
                model,
                X
            )

            # - Compute the natural loss and backprop
            nat_loss = self.natural_loss(model(X), y)
            nat_loss.backward()

            # - Combine autodiff and numerical gradients
            for name,v in model.named_parameters():
                v.grad.data += self.beta_robustness * adv_loss_gradients[name]
            
            return nat_loss.detach() + self.beta_robustness * adv_loss
        else:
            nat_loss = self.natural_loss(model(X), y)
            nat_loss.backward()
            return nat_loss.detach()

And this is a script that uses it:

# - The CUBLAS env. variable below is required for deterministic cuBLAS ops (e.g. linear layers) on GPU
from copy import deepcopy
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

# - Import the adversarial loss
from torch_loss import AdversarialLoss

def eval_test_set(
    test_dataloader,
    net,
):
    net.eval()
    N_correct = 0
    N = 0
    for (X,y) in test_dataloader:
        X, y = X.to(device), y.to(device)
        y_hat = torch.argmax(net(X), axis=1)
        N += len(y)
        N_correct += (y_hat == y).int().sum()
    net.train()
    return N_correct / N

def eval_test_set_mismatch(
    test_dataloader,
    net,
    mismatch,
    n_reps,
    device
):
    net_theta_star = deepcopy(net)
    test_acc_no_noise = eval_test_set(test_dataloader, net)
    test_accs = []
    for idx in range(n_reps):
        print("Test eval. mismatch rob. %d/%d" % (idx,n_reps))
        theta_star = {}
        for name,v in net.named_parameters():
            theta_star[name] = v + v.abs() * mismatch * torch.randn(size=v.shape, device=device)
        net_theta_star.load_state_dict(theta_star)
        test_accs.append(eval_test_set(test_dataloader, net_theta_star))
    return float(test_acc_no_noise), float(sum(test_accs)/len(test_accs))

def init_weights(lyr):
    if isinstance(lyr, (torch.nn.Linear,torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(lyr.weight)
        lyr.bias.data.fill_(0.01)

class TorchCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, out_channels=64, kernel_size=(4,4), stride=(1,1), padding="same")
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4,4), stride=(1,1), padding="valid")
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.linear1 = torch.nn.Linear(in_features=1600, out_features=256)
        self.linear2 = torch.nn.Linear(in_features=256, out_features=64)
        self.linear3 = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 1600)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

if __name__ == '__main__':
    torch.manual_seed(0)
    # - Avoid reprod. issues caused by GPU
    torch.use_deterministic_algorithms(True)
    # - Select device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # - Select which device
    if torch.cuda.device_count() == 2:
        device = "cuda:1"

    # - Fixed parameters
    BATCH_SIZE_TRAIN = 100
    BATCH_SIZE_TEST = 500
    N_EPOCHS = 5
    LR = 1e-4

    base_dir = os.path.dirname(os.path.abspath(__file__))

    download_path = os.path.join(base_dir, "fmnist")
    train_set = torchvision.datasets.FashionMNIST(
        download_path,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.25])])
    )
    test_set = torchvision.datasets.FashionMNIST(
        download_path,
        download=True,
        train=False,
        transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.25])])
    )
    train_dataloader = DataLoader(
        dataset=train_set,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        num_workers=4
    )
    test_dataloader = DataLoader(
        dataset=test_set,
        batch_size=BATCH_SIZE_TEST,
        shuffle=False,
        num_workers=4
    )

    # - Create Torch network
    cnn = TorchCNN().to(device)
    cnn.apply(init_weights)

    # - Create adam instance for torch
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)

    # - Adversarial loss
    adv_loss = AdversarialLoss(
        model=cnn,
        natural_loss=torch.nn.CrossEntropyLoss(reduction="mean"),
        robustness_loss=torch.nn.KLDivLoss(reduction="batchmean"),
        device=device,
        n_attack_steps=10,
        mismatch_level=0.025,
        initial_std=1e-3,
        beta_robustness=0.25
    )

    for epoch_id in range(N_EPOCHS):
        for idx,(X,y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)
            robustness_loss = adv_loss.compute_gradient_and_backward(
                model=cnn,
                X=X,
                y=y
            )

            # - Backward does not need to be called
            # - Update the weights
            optimizer.step()

            # - Zero out the grads of the optimizer
            optimizer.zero_grad()

            if idx % 100 == 0:
                test_acc_no_noise, mean_noisy_test_acc = eval_test_set_mismatch(
                    test_dataloader,
                    cnn,
                    mismatch=0.2,
                    n_reps=5,
                    device=device
                )
                print("\n\nTest acc %.5f Mean noisy test acc %.5f" % (test_acc_no_noise,mean_noisy_test_acc))

            print("Epoch %d Batch %d/%d Loss %.5f" % (epoch_id,idx,len(train_dataloader),robustness_loss))