Adding an averaged tensor to a single channel of another tensor

I have a tensor of size (b, c, w, h) and another tensor of the same size. Both of them are encoded features produced by previous conv layers.

The first tensor is averaged over channels to size (b, 1, w, h). I want to add it to the second tensor, but in a single channel only. What I mean is:

x1 = torch.mean(x1, dim=1, keepdim=True)
x2[:, 0:1] += x1

This produced an in-place operation error: "one of the variables needed for gradient computation has been modified by an inplace operation".

Both tensors must keep gradient flow.
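For reference, the failure can be reproduced in isolation. A minimal sketch (shapes made up; torch.sigmoid stands in for any op whose backward pass saves its output):

```python
import torch

b, c, w, h = 2, 4, 3, 3
a = torch.randn(b, c, w, h, requires_grad=True)
x2 = torch.sigmoid(a)  # sigmoid saves its output for the backward pass
x1 = torch.mean(torch.randn(b, c, w, h), dim=1, keepdim=True)

x2[:, 0:1] += x1  # in-place write bumps x2's version counter

err = None
try:
    x2.sum().backward()
except RuntimeError as e:
    err = e  # "... has been modified by an inplace operation"
```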

Hi Falmasri!


x2 = x2[:, 0:1] + x1
# or  x3 = x2[:, 0:1] + x1

Some comments:

x2[:, 0:1] += x1 does modify x2 inplace, because you are indexing
into x2.

This is not necessarily a problem, but it can break the computation
graph – as appears to be happening in your use case – depending
upon what else is going on in your forward pass.

The solution is to not modify the tensor in question inplace.

Note, x2 = x2[:, 0:1] + x1 does not modify the original tensor to
which the python variable x2 referred. Instead, it creates a new tensor
and then sets the python variable x2 to refer to this new tensor. (If
autograd’s computation graph needs the original x2 tensor for the
backward pass, autograd will keep a separate reference to it so that
it won’t get garbage-collected when x2 no longer refers to it.)
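If you need the result to keep the full (b, c, w, h) shape, one out-of-place option is to rebuild the tensor with torch.cat. A sketch, with made-up shapes and channel 0 as the target:

```python
import torch

b, c, w, h = 2, 4, 3, 3
a = torch.randn(b, c, w, h, requires_grad=True)
x2 = torch.sigmoid(a)
x1 = torch.mean(torch.randn(b, c, w, h), dim=1, keepdim=True)

# New tensor: channel 0 with x1 added, remaining channels untouched.
out = torch.cat([x2[:, 0:1] + x1, x2[:, 1:]], dim=1)
out.sum().backward()  # works: nothing was modified in place
```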


K. Frank

Thank you KFrank, but this doesn’t answer my question; it creates a new tensor instead. I want to update a single channel’s values inside the tensor x2.

Hi Falmasri!

Sometimes autograd needs to reuse tensors from the forward pass in
the backward pass. If you modify such a tensor inplace, you will break
the computation graph and get the error message you quoted, “one of
the variables needed for gradient computation has been modified by an
inplace operation.”

You have a choice: You can “update a single channel value inside the
tensor x2” and break the computation graph, or you can “create a new
tensor” and have the backward pass work.

That’s just how pytorch works.


K. Frank

I solved it indirectly, but is there a better way?

x1 = torch.mean(x1, dim=1, keepdim=True)
z = torch.zeros_like(x2)  # zeros of shape (b, c, w, h) on x2's device
z[:, i:i+1] = x1          # i is the target channel index
x2 = x2 + z
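For completeness, here is a self-contained version of that workaround with a gradient check (shapes, the channel index i, and the name x1_full are all made up for illustration):

```python
import torch

b, c, w, h = 2, 4, 3, 3
i = 0  # hypothetical target channel index
a = torch.randn(b, c, w, h, requires_grad=True)
x1_full = torch.randn(b, c, w, h, requires_grad=True)

x2 = torch.sigmoid(a)
x1 = torch.mean(x1_full, dim=1, keepdim=True)

z = torch.zeros_like(x2)  # fresh tensor: writing into it is safe
z[:, i:i+1] = x1          # place the averaged map in channel i
x2 = x2 + z               # out of place: the sigmoid output is untouched

x2.sum().backward()       # gradients flow to both inputs
```

Only channel i of z is nonzero, so this adds x1 to a single channel of x2 while keeping the operation on x2 itself out of place.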