How to replace some elements in a Variable with others in a differentiable way?

I have a BxCxHxW Variable that has a graph associated with it (it is the output of a CNN) and another similar Variable that is the output of a second CNN. I want to replace a part of the first with the second, pass the new Variable through a third network, and then call backward on the loss.

Right now I run into an error because some Variables have been modified in-place. To fix that, I call clone and then replace entries in the cloned Variable one by one, like this:

    def refine(old_rep, fine_rep, ys, xs):
        new_rep = old_rep.clone()  # clone so old_rep is not modified in-place
        for i in range(fine_rep.size(0)):
            img_id = i // 10       # 10 refined patches per image
            y = int(ys[i])
            x = int(xs[i])
            new_rep[img_id, :, y:y+1, x:x+1] = fine_rep[i]
        return new_rep

Is there a faster way to do the same?

Thanks! 🙂


Here is a snippet that works for me, doing what you described:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(10), requires_grad=True)
    y = x ** 2             # non-leaf Variable with a graph attached
    a = Variable(torch.randn(10), requires_grad=True)
    b = a ** 2
    y[5:] = b[5:]          # in-place slice assignment, tracked by autograd
    z = y ** 3
    z.sum().backward()     # gradients flow back to both x and a
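
As for the speed question: the whole refine loop can usually collapse into a single advanced-indexing assignment on the clone. A minimal sketch for a recent PyTorch version, assuming fine_rep has shape (B*10, C) with one (y, x) location per row (squeeze any trailing 1x1 dims first) and ys/xs are 1-D tensors; the name refine_vectorized is hypothetical:

    import torch

    def refine_vectorized(old_rep, fine_rep, ys, xs):
        # Clone so old_rep itself is never modified in-place.
        new_rep = old_rep.clone()
        img_ids = torch.arange(fine_rep.size(0)) // 10  # 10 patches per image
        # One advanced-indexing assignment replaces the Python loop;
        # autograd records it as a single in-place op on the clone.
        new_rep[img_ids, :, ys.long(), xs.long()] = fine_rep
        return new_rep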

Can you give me a small snippet of the failure case?
I’m wondering whether the last layer of your first and second CNN has an in-place ReLU or something similar, which would prevent another in-place op on their output. That should be easy to work around (make the ReLU non-in-place).
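
For reference, a minimal reproduction of that restriction looks roughly like this (a sketch; the exact error message varies across versions):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    x = Variable(torch.randn(10), requires_grad=True)
    y = F.relu(x * 2, inplace=True)  # backward of in-place ReLU needs y intact
    y[5:] = 0                        # second in-place op on the same output
    (y ** 2).sum().backward()        # raises the in-place modification RuntimeError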


> last layer of your first and second CNN have in-place ReLU

Yes! I understand the reason now. I had an in-place ReLU from torchvision’s ResNet.
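
For anyone who hits the same thing: one way to patch a torchvision model is to flip the inplace flag on all of its ReLU modules (a sketch; resnet18 is just an example here):

    import torch.nn as nn
    import torchvision.models as models

    model = models.resnet18()
    # Make every ReLU non-in-place so the network output can later be
    # modified in-place (e.g. by slice assignment) without breaking autograd.
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            module.inplace = False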

Thanks!