Chunk + cat data assignment does not propagate

Hello,
I am breaking a tensor into chunks so that I can individually set the `requires_grad` attribute based on which values I want to optimize:

```python
import torch

ones = torch.ones(3)
chunks = torch.chunk(ones, 3, 0)
# using chunks as input to the optimizer
k = torch.cat(chunks)
print(k)
# Output: tensor([1., 1., 1.])
k[0] = 2
print(k)
# Output: tensor([2., 1., 1.])

print(chunks)
# Output: (tensor([1.]), tensor([1.]), tensor([1.]))
```

I am doing this so that I don't have to change most of my code: I first use `torch.chunk` to break up the tensor and then `torch.cat` to merge it back. However, changing the values of `k` doesn't seem to be reflected in `chunks`. I don't understand where the flaw in my intuition is.

Thank you.

`torch.chunk` creates views of the original `ones` tensor. `torch.cat` does not create a view; it creates a brand new tensor by copying the data. This is why modifying the output of `torch.cat` (`k`) doesn't change `chunks`: their underlying storage is completely unrelated.
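To make the view/copy distinction concrete, here is a minimal sketch (reusing the variable names from the snippet above): writing through a chunk mutates `ones`, while writing through the `cat` result does not.

```python
import torch

ones = torch.ones(3)
chunks = torch.chunk(ones, 3, 0)

# chunks are views: writing through a chunk mutates the original tensor
chunks[0][0] = 5.0
print(ones)  # tensor([5., 1., 1.])

# torch.cat copies, so k has its own storage
k = torch.cat(chunks)
print(k.data_ptr() == ones.data_ptr())  # False

# writing into k leaves ones (and chunks) untouched
k[0] = 2.0
print(ones)  # tensor([5., 1., 1.])
```

Checking `data_ptr()` is a quick way to see whether two tensors share the same underlying storage.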


I am trying to understand the difference better. Given

```python
x = torch.cat([x1, x2, x3])
out = f(x)
```

would the gradients propagate to `x1`, `x2`, `x3` as expected, or do they not get any gradients?

Also, is there a good way of merging all the chunks together so that I can perform operations on them as if they were a single tensor?

Thank you.

The gradients will propagate as long as `x1, x2, ...` require grad.
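For example, here is a minimal sketch (assuming a simple quadratic as the function `f`) showing that gradients flow back through `torch.cat` to the inputs that require grad, while inputs with `requires_grad=False` stay frozen:

```python
import torch

x1 = torch.ones(1, requires_grad=True)
x2 = torch.ones(1, requires_grad=True)
x3 = torch.ones(1, requires_grad=False)  # frozen chunk

x = torch.cat([x1, x2, x3])
out = (x ** 2).sum()  # stand-in for f(x)
out.backward()

print(x1.grad)  # tensor([2.])
print(x2.grad)  # tensor([2.])
print(x3.grad)  # None
```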

> Also is there a good way of merging all the chunks together, so that I can perform operations on them as if they are a single tensor?

You could use `x = torch.cat([x1, x2, x3])`, apply some in-place operations, and then split it back up via `y1, y2, y3 = x.chunk(3)`.
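A minimal sketch of that cat/operate/chunk pattern (assuming three leaf tensors with mixed `requires_grad` flags; the doubling is just a placeholder operation):

```python
import torch

x1 = torch.ones(1, requires_grad=True)
x2 = torch.ones(1, requires_grad=True)
x3 = torch.ones(1, requires_grad=False)  # frozen chunk

x = torch.cat([x1, x2, x3])  # work on everything as one tensor
x = x * 2.0                  # placeholder for the actual operations
y1, y2, y3 = x.chunk(3)      # split back into per-chunk pieces

loss = torch.cat([y1, y2, y3]).sum()
loss.backward()

print(x1.grad)  # tensor([2.])
print(x3.grad)  # None -- the frozen chunk gets no gradient
```

Since `y1`, `y2`, `y3` are produced by autograd-tracked ops, gradients still reach the original leaf tensors after the round trip.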
