Autograd.grad() for Tensor

I want to compute the gradient between two tensors in a net. The input tensor X is sent through a set of convolutional layers which give me back an output tensor Z.

I’m creating a new loss, and I would like to compute the MSE between the gradient of the output’s norm w.r.t. each element of X and a ground-truth tensor. Here is the code:

import torch
import torch.nn as nn

# Starting tensors
X = torch.rand(40, requires_grad=True)
Y = torch.rand(40, requires_grad=True)

# Define loss
loss_fn = nn.MSELoss()

# Make some calculations
V = Y*X+2

# Compute the norm
V_norm = V.norm()

# Computing gradient to calculate the loss
for i in range(len(V)):
    if i == 0:
        grad_tensor = torch.autograd.grad(outputs=V_norm, inputs=X[i])
    else:
        grad_tensor_ = torch.autograd.grad(outputs=V_norm, inputs=X[i])
        grad_tensor = torch.cat((grad_tensor, grad_tensor_), dim=0)
        
# Ground truth
gt = grad_tensor * 0 + 1

# Loss
loss_g = loss_fn(grad_tensor, gt)
print(loss_g) 

Unfortunately, I have been experimenting with torch.autograd.grad(), but I could not figure out how to do it. I get the following error: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Setting allow_unused=True gives me back None, which is not an option. I am not sure how to compute the loss between the gradients and the ground truth. Any idea how to code this loss?

Hi,

I think the problem is that you forgot that indexing is a proper operation in itself. X[i] returns a different Tensor than X, and that new Tensor was not used to compute the norm, hence the error.
If you pass X itself as the input, you will get the gradients for all of X at once (as the tensor returned by torch.autograd.grad, or in X.grad if you call .backward() instead) and can then index into it.
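For reference, a minimal sketch of the corrected snippet based on the code above; create_graph=True is only needed if you want to backpropagate through the gradient later (e.g. to use it in a loss):

import torch
import torch.nn as nn

X = torch.rand(40, requires_grad=True)
Y = torch.rand(40, requires_grad=True)
loss_fn = nn.MSELoss()

V = Y * X + 2
V_norm = V.norm()

# Pass the full X tensor; autograd.grad returns a tuple containing one
# gradient tensor with the same shape as X.
grad_X, = torch.autograd.grad(outputs=V_norm, inputs=X, create_graph=True)

gt = torch.ones_like(grad_X)   # ground truth
loss_g = loss_fn(grad_X, gt)   # differentiable because of create_graph=True
print(loss_g)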


That was helpful, thanks. Another question, related to using this as part of the final loss together with other criteria. Currently I’m doing loss = criterion_1 + loss_g and then loss.backward(). Is this correct, or is it propagated twice: once in torch.autograd.grad(outputs=V_norm, inputs=X, retain_graph=True) and once in loss.backward()? Thanks.

Yes, this is correct. You will get the gradient corresponding to the sum of your losses, which is the sum of the individual gradients in this case.
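As a sketch, continuing the snippet above (with a stand-in MSE term in place of your criterion_1), the combined loss needs only one backward call; note that loss_g only contributes gradients if the autograd.grad call used create_graph=True:

grad_X, = torch.autograd.grad(outputs=V_norm, inputs=X, create_graph=True)
loss_g = loss_fn(grad_X, torch.ones_like(grad_X))

loss_1 = loss_fn(V, torch.zeros_like(V))   # stand-in for criterion_1
loss = loss_1 + loss_g

# Single backward pass; the gradient of the sum is the sum of the gradients.
loss.backward()
print(X.grad.shape)   # torch.Size([40])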

@ptrblck
I am trying to get the outputs from several models, stack them together, and feed them into another classifier, but I get the following error:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Below is the forward method of my model:

def forward_model(self, x, member_id=None):
    mem_outputs = []
    for i in range(5):
        output = self.models[i](x)
        mem_outputs.append(output)

    ensemble_inputs = torch.stack(mem_outputs, dim=1).unsqueeze(dim=1)
    ensemble_inputs = (ensemble_inputs - self.mu) / self.std
    ensemble_outputs = self.classifier(ensemble_inputs)
    return ensemble_outputs

Could you post the code you are using to calculate the gradients?
I guess this error might be raised by the direct gradient calculation via autograd.grad?

Yes I use autograd.grad.

Here is the code:

import os
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
## load mnist dataset
use_cuda = torch.cuda.is_available()

root = './data'
if not os.path.exists(root):
    os.mkdir(root)
    
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
# if not exist, download mnist dataset
train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

batch_size = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

print('==>>> total training batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

train_args =  {
            "epsilon": 0.3,
            "step_size": 0.1,
            "num_steps": 10}

## network

class Net_Ensemble(nn.Module):
    def __init__(self, num_ens, num_classes):
        super(Net_Ensemble, self).__init__()
        self.num_feature = num_ens*num_classes*10
        self.num_classes = num_classes
        self.conv_1 = nn.Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1))
        self.conv_2 = nn.Conv2d(4, 10, kernel_size=(1, 1), stride=(1, 1))
        self.fc_1 = nn.Linear(in_features=self.num_feature, out_features=100, bias=True)
        self.fc_2 = nn.Linear(in_features=100, out_features=10, bias=True)
        self.fc_3 = nn.Linear(in_features=200, out_features=100, bias=True)
        self.fc_4 = nn.Linear(in_features=100, out_features=10, bias=True)

        self.num_steps = train_args["num_steps"]
        self.epsilon = train_args["epsilon"]
        self.step_size = train_args["step_size"]
        self.rand = False

    def forward_model(self, x):
        #x = -x-280
        x = F.relu(self.conv_1(x))
        x = F.relu(self.conv_2(x))
        x = x.view(-1, self.num_feature)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

    def forward(self, x):
        outputs = self.forward_model(x.detach())    
        return outputs



class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    def name(self):
        return "LeNet"


class model_combined(nn.Module):
    def __init__(self, netlist_models, voting_net, mu, std):
        super(model_combined, self).__init__()
        self.models = nn.ModuleList(netlist_models)       
        self.classifier = voting_net
        self.mu = nn.Parameter(mu, requires_grad=False)
        self.std = nn.Parameter(std, requires_grad=False)

        self.num_steps = train_args["num_steps"]
        self.epsilon = train_args["epsilon"]
        self.step_size = train_args["step_size"]
        self.rand = False

    def forward_model(self, x, member_id=None):
        outputs = []
        if member_id is None:
            range_members = range(0, len(self.models))
        else:
            range_members = range(member_id, member_id+1)
        mem_outputs = []
        for i in range_members:
            output = self.models[i](x)
            mem_outputs.append(output)

        ensemble_inputs = -torch.stack(mem_outputs, dim=1).unsqueeze(dim=1)
        ensemble_inputs = (ensemble_inputs - self.mu) / self.std
        ensemble_outputs = self.classifier(ensemble_inputs)
        return ensemble_outputs

    def forward(self, x, target=None, make_adv=False):

        if make_adv:
            if self.rand:
                x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)

            prev_training = bool(self.training)
            self.eval()

            inputs = x.clone()

            for i in range(self.num_steps):
                x = x.clone().detach().requires_grad_(True)
                outputs = self.forward_model(x)
                losses = criterion(outputs, target)
                loss = torch.mean(losses)
                grad, = torch.autograd.grad(loss, [x])
                with torch.no_grad():
                    step = torch.sign(grad) * self.step_size
                    diff = x + step - inputs
                    diff = torch.clamp(diff, -self.epsilon, self.epsilon)
                    x = torch.clamp(diff + inputs, 0, 1)

            if prev_training:
                self.train()

        outputs = self.forward_model(x.detach())
        return outputs

Net = Net_Ensemble(1, 10)

mu = torch.tensor([0.0])
std = torch.tensor([1.0])

## training
models = [LeNet()]
model = model_combined(models, Net, mu, std)

if use_cuda:
    model = model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    # training
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        if use_cuda:
            x, target = x.cuda(), target.cuda()
        out = model(x, target=target, make_adv=True)
        loss = criterion(out, target)
        ave_loss = ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        if (batch_idx+1) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, ave_loss))

Thanks for the code.
You are detaching x in Net_Ensemble.forward, and this module is used as self.classifier in model_combined.forward_model. Detaching the activation cuts the computation graph, so Autograd won’t be able to backpropagate to the original input anymore.
After removing the .detach() call, the error is gone (I used dummy tensors to test it, not the complete script).
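In other words, a minimal sketch of the change in Net_Ensemble (not the full script):

    def forward(self, x):
        # Keep x attached so Autograd can backpropagate through the
        # classifier to the original input.
        outputs = self.forward_model(x)
        return outputs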