One of the variables needed for gradient computation has been modified by an inplace operation (Changing patch inside an image during training)

I am trying to do the following:

  1. Pass an image to a neural network and predict a patch and its location
  2. Apply the patch to the image (change the intensities of the image by the values of the patch at the predicted coordinates)
  3. Pass this modified image to a pre-trained classifier model which is frozen (model.eval())
  4. Essentially, I want to create a patch at some location on the image that causes an adversarial attack.

Below are the chunks of code I am using for the above steps:

class PatchModel(nn.Module):
    def __init__(self, sz):
        super().__init__()
        self.sz = sz
        # MLP: flattened 28x28 MNIST image -> sz*sz patch values + 2 coordinates
        layers = [Flatten(), nn.Linear(784, 100), nn.ReLU(inplace=True), nn.Linear(100, sz * sz + 2)]
        self.m = nn.Sequential(*layers)
    def forward(self, x):
        img_sz = x.size()[2:]
        o = self.m(x)
        # Keep the predicted coordinates inside the image bounds
        o[:, :2] = torch.clamp(o[:, :2], 0, img_sz[0] - self.sz)
        return o

This is the model that takes in the image and predicts the patch and its coordinates.

def apply_to_img(img, op):
    # First two outputs are the patch coordinates, the rest are the 4x4 patch values
    patch = op[:, 2:].view(1, 1, 4, 4)
    coord = op[:, :2].int()
    # Zero canvas of the same size as the image, with the patch placed at the predicted location
    _temp = torch.zeros(*img.size())
    _temp[0, 0, coord[0, 0]:coord[0, 0] + 4, coord[0, 1]:coord[0, 1] + 4] = patch
    # Adding patch to image
    img = img + _temp.clone()
    return img

I defined a function that takes the image and the output of the first network and adds the patch to the image.

def epoch(loader, model, opt=None):
    """Standard training/evaluation epoch over the dataset"""
    total_loss, total_err = 0., 0.
    for X, y in loader:
        # Pass the image into the patch model
        yp_ = patch_model(X)
        # Use the apply function to apply the patch to the image
        yp = apply_to_img(X, yp_)
        # This is the frozen pretrained model for which I want to create an adversarial patch
        yp = model(yp)
        # Push the prediction away from the true label and towards target class 3
        loss = -nn.CrossEntropyLoss()(yp, y) + nn.CrossEntropyLoss()(yp, torch.LongTensor([3]))
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()

        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)

This is the main training loop. I pass an AdamW optimizer to the function above:

opt = AdamW(patch_model.parameters(), 0.001)
I am getting the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-150-e28149387ca7> in <module>
----> 1 epoch(train_loader, model.eval(), opt)

<ipython-input-149-05b573ec2ee5> in epoch(loader, model, opt)
     10         if opt:
     11             opt.zero_grad()
---> 12             loss.backward()
     13             opt.step()
     14 

/usr/local/lib/python3.6/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    100                 products. Defaults to ``False``.
    101         """
--> 102         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    103 
    104     def register_hook(self, hook):

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

What can I do to prevent this error? The error goes away if I detach the output of the PatchModel network inside apply_to_img, but then, as you would expect, the PatchModel parameters don't get updated; they remain the same after running epoch.

I am using MNIST data with batch size 1, hence I hardcoded the indices directly in the apply_to_img function.

Your inplace operation seems to be o[:, :2] = torch.clamp(...). You could use o = torch.cat([o[:, :2].clamp(...), o[:, 2:]], dim=1) instead.
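For illustration, the forward method could then look something like this (just a sketch, keeping the rest of your model as is):

def forward(self, x):
    img_sz = x.size()[2:]
    o = self.m(x)
    # Build a new tensor instead of writing into o in place, so autograd
    # can still use the unmodified o during the backward pass
    coords = torch.clamp(o[:, :2], 0, img_sz[0] - self.sz)
    return torch.cat([coords, o[:, 2:]], dim=1)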

Best regards

Thomas

Hi, thanks for helping me out. The code is running now, but I am facing another difficulty: my patch_model is not changing its parameters. On further inspection, the grad of all parameters is 0 after running the epoch. I tried removing .clone() from _temp in my apply_to_img, but the result is the same. Any insight regarding this would be very helpful. One more change I made in apply_to_img: I clamped my patch to (0, 1) so that the values are in the range of pixel values.
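I checked the gradients roughly like this:

for name, p in patch_model.named_parameters():
    # p.grad is populated after loss.backward(); all of these print 0
    print(name, p.grad.abs().max().item())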

For reference, I am calling the epoch function this way:
epoch(train_loader, model.eval(), opt)
where model is the pretrained model for MNIST.

The patch is probably mostly 0 or 1 after clamping, right?
If that is the case, then the gradient is 0, as changing the clamped value a little bit would not change the output. This is why many nets use squashing to 0…1, e.g. with a sigmoid, instead, to move their outputs into the desired range.
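A tiny toy example of the difference (the values here are made up, just to show where the gradient dies):

import torch

o = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)

# Hard clamp: the gradient is exactly 0 wherever the input is outside (0, 1)
o.clamp(0, 1).sum().backward()
print(o.grad)   # tensor([0., 1., 0.])

o.grad = None
# Sigmoid: the gradient is small near 0 and 1, but never exactly 0
torch.sigmoid(o).sum().backward()
print(o.grad)   # tensor([0.1050, 0.2350, 0.0452])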

Best regards

Thomas

Thanks a lot. I tried using a sigmoid to squash the patch values to (0, 1), and now I am successfully able to attack the test data with 68% accuracy. But the problem is, even after using the sigmoid, the patch values are either 0 or 1. And the coordinates are always predicted in the top-left part of each image. I am attaching a sample output from my model; the patch is in almost the same place for all images.

To keep the values low, I am thinking about adding the norm of the patch to my loss function.
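Something like this, perhaps (a sketch; lambda_norm is a made-up weighting coefficient I would still have to tune):

# Inside the epoch loop, after squashing the patch values with sigmoid
patch_vals = torch.sigmoid(yp_[:, 2:])
loss = (-nn.CrossEntropyLoss()(yp, y)
        + nn.CrossEntropyLoss()(yp, torch.LongTensor([3]))
        + lambda_norm * patch_vals.norm(p=2))  # penalize large patch values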

That is perhaps something about your method? I find it not terribly surprising: if you optimize, you'll have some gradient, and then you push everything with one sign of the gradient towards 0 and the rest towards 1. That is why adversarial attacks usually specify how far they are allowed to go in some norm (e.g. l-infinity aka max-norm for FGSM, or the l² or l¹ norm for some others).
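For instance, a minimal FGSM step (not your patch setup, just to illustrate the l-infinity budget; epsilon is the allowed perturbation size):

# One FGSM step: move each pixel by at most epsilon in the direction of the gradient sign
X.requires_grad_(True)
loss = nn.CrossEntropyLoss()(model(X), y)
loss.backward()
X_adv = (X + epsilon * X.grad.sign()).clamp(0, 1).detach()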

That probably is related to how you generate your patch coordinates.

Best regards

Thomas
