Calculate gradient of output w.r.t. input before adding noise

Hi all,
Suppose my input img is processed by adding noise (noisy_img) before it is fed into the model. When I try gradients = autograd.grad(outputs=output, inputs=img), I can't get the gradient. But if I use gradients = autograd.grad(outputs=output, inputs=noisy_img), it seems to work without error. I set requires_grad=True on both img and noisy_img.
What’s the proper way to get the gradients of output w.r.t. img in this case?
Thank you.

Hi,

You need to make sure that you set the requires_grad field on img before using it to compute noisy_img.
If you do so, can you share the code that generates the noisy image as well as the error you see?
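A minimal CPU-only sketch of that advice, assuming a single toy linear layer as a hypothetical stand-in for the real model (the layer, shapes and noise are illustrative, not from the thread). `requires_grad_()` is called on `img` first and the noise is added afterwards, so `torch.autograd.grad` can reach `img` through the addition. Because the noise is additive, the gradient w.r.t. `img` comes out equal to the gradient w.r.t. `noisy_img`.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the real model: one linear layer.
model = torch.nn.Linear(4, 2)

img = torch.rand(3, 4)
img.requires_grad_()            # mark img BEFORE it is used to build noisy_img

noise = torch.randn_like(img)   # the noise itself does not need gradients
noisy_img = img + noise         # non-leaf tensor whose history includes img

output = model(noisy_img)

# Gradient of sum(output) w.r.t. both the clean and the noisy image.
grad_img, grad_noisy = torch.autograd.grad(
    outputs=output,
    inputs=(img, noisy_img),
    grad_outputs=torch.ones_like(output),
)
print(grad_img.shape)                         # torch.Size([3, 4])
print(torch.allclose(grad_img, grad_noisy))   # True: d(noisy_img)/d(img) is the identity
```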

Hi albanD,
Here is a simple DAE (denoising autoencoder) example:

import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import torch.autograd as autograd
import numpy as np
np.random.seed(42)
torch.manual_seed(42)

# ref: https://github.com/ReyhaneAskari/pytorch_experiments/blob/master/DAE.py

num_epochs = 20
batch_size = 128 
learning_rate = 1e-3
img_transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = MNIST('../data', transform=img_transform, download=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def to_img(x):
    x = x.view(x.size(0), 1, 28, 28)
    return x

def add_noise(img):
    mean=0.
    std=1
    noise = torch.randn(img.size()) * std + mean
    noisy_img = img + noise.cuda()
    return noisy_img

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(True),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

def calc_gradient_penalty(output, img):
    

    gradients = autograd.grad(outputs=output, inputs=img,
                              grad_outputs=torch.ones(output.size()).cuda(),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 
   
    return gradient_penalty

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img).cuda()
        img = autograd.Variable(img, requires_grad=True)
       
        noisy_img = add_noise(img)
        noisy_img = Variable(noisy_img).cuda()
        
        # ===================forward=====================
        output = model(noisy_img)
        gradient_penalty = calc_gradient_penalty(output, img)
        print(gradient_penalty)
        
        img.requires_grad = False
        loss = criterion(output, img)
        loss = loss + gradient_penalty

        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The error I get, before the print is even reached, is: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
Thank you.
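A stripped-down reproduction of this error, assuming the same kind of toy linear layer in place of the autoencoder. Once `noisy_img` is cut out of the graph (made explicit here with `.detach()`), `img` is no longer an input of `output`, so `autograd.grad` raises. With `allow_unused=True` the call returns `None` for `img` instead of raising.

```python
import torch

model = torch.nn.Linear(4, 2)   # hypothetical stand-in for the autoencoder

img = torch.rand(3, 4, requires_grad=True)
noisy_img = img + torch.randn_like(img)
noisy_img = noisy_img.detach()  # breaks the link between img and everything downstream

output = model(noisy_img)

try:
    # retain_graph=True so the same graph can be queried again below
    torch.autograd.grad(outputs=output, inputs=img,
                        grad_outputs=torch.ones_like(output),
                        retain_graph=True)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # the "not have been used in the graph" error quoted above

# With allow_unused=True, the unreachable input simply gets None.
(grad_img,) = torch.autograd.grad(outputs=output, inputs=img,
                                  grad_outputs=torch.ones_like(output),
                                  allow_unused=True)
print(grad_img)  # None
```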

Hi,

You should remove all the calls to Variable and that should fix it:

        img = img.view(img.size(0), -1)
        img = img.cuda().requires_grad_()
       
        noisy_img = add_noise(img)
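For reference, a condensed CPU-only sketch of the corrected training step, assuming random tensors in place of MNIST batches and a smaller autoencoder (both illustrative choices, not from the thread). It keeps the posted structure but drops every `Variable` and `.cuda()` call, so the penalty is taken w.r.t. the clean `img`. One design choice differs from the posted loop: the BCE target is `img.detach()` instead of flipping `img.requires_grad` back to `False` after the penalty.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy autoencoder (hypothetical sizes; the thread uses 784 -> 256 -> 64).
model = nn.Sequential(
    nn.Linear(16, 8), nn.ReLU(),
    nn.Linear(8, 16), nn.Sigmoid(),
)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)


def add_noise(img, std=1.0):
    # Additive Gaussian noise; randn_like matches img's device and dtype.
    return img + torch.randn_like(img) * std


def calc_gradient_penalty(output, img):
    # d(sum of outputs)/d(img); create_graph=True keeps it differentiable
    # so the penalty itself can be backpropagated into the model weights.
    (gradients,) = torch.autograd.grad(outputs=output, inputs=img,
                                       grad_outputs=torch.ones_like(output),
                                       create_graph=True)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


losses = []
for step in range(5):
    img = torch.rand(8, 16)      # stand-in for a flattened batch in [0, 1]
    img.requires_grad_()         # set BEFORE building noisy_img, no Variable wrapper

    noisy_img = add_noise(img)   # stays connected to img
    output = model(noisy_img)

    gradient_penalty = calc_gradient_penalty(output, img)
    loss = criterion(output, img.detach()) + gradient_penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses)
```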

Hi,
Thank you, it works. Can you explain why img = autograd.Variable(img, requires_grad=True) doesn't work in this case? It still sets requires_grad=True, right?
Thanks for your time.

I think the problem is Variable(noisy_img): it effectively does the same as noisy_img.detach(), so it breaks the graph between img and noisy_img.
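A quick check of that explanation, assuming PyTorch 0.4 or later, where `Variable` is a deprecated compatibility wrapper around `Tensor`. Wrapping a tensor that has history yields a fresh leaf with no `grad_fn` and `requires_grad=False`, just like `.detach()`, while still sharing the same storage.

```python
import torch
from torch.autograd import Variable  # deprecated, but still importable

img = torch.rand(2, 3, requires_grad=True)
noisy_img = img + torch.randn_like(img)
print(noisy_img.grad_fn)  # an AddBackward0 node: still connected to img

wrapped = Variable(noisy_img)  # what the original loop did
detached = noisy_img.detach()

# Both are fresh leaves with no history and no gradient requirement.
for t in (wrapped, detached):
    print(t.is_leaf, t.requires_grad, t.grad_fn)  # True False None

# The data is shared with noisy_img; only the graph link is gone.
print(wrapped.data_ptr() == noisy_img.data_ptr())  # True
```

The earlier line `img = autograd.Variable(img, requires_grad=True)` is harmless by comparison: `img` has no history to lose, so it just becomes a new leaf that requires grad.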
