Chunk + cat data assignment does not propagate

Hello,
I am breaking a tensor into chunks so that I can individually set the requires_grad flag based on which values I want to optimize.

import torch

ones = torch.ones(3)
chunks = torch.chunk(ones, 3, 0)
# using chunks as input to optimizer
k = torch.cat(chunks)
print(k)
# Output: tensor([1., 1., 1.])
k[0] = 2
print(k)
# Output: tensor([2., 1., 1.])

print(chunks)
# Output: (tensor([1.]), tensor([1.]), tensor([1.]))

I am doing this so that I don’t have to change most of my code: I first use the chunk function to break the tensor apart and then use the cat function to join the pieces back together. However, changing the values of the variable ‘k’ doesn’t seem to be reflected in the variable ‘chunks’. I don’t understand where the flaw in my intuition is.

Thank you.

torch.chunk creates views of the original ones tensor. torch.cat does not create a view; it creates a brand-new tensor by copying the data. This is why modifying the output of torch.cat (k) doesn’t change chunks: k and chunks do not share any storage.
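A small sketch of the view-vs-copy distinction, reusing the names from the question. It writes through the base tensor (visible in the chunk views) and through the cat output (visible nowhere else), and compares data_ptr() values to show which tensors share memory. The value 5.0 is just an arbitrary marker.

```python
import torch

ones = torch.ones(3)
chunks = torch.chunk(ones, 3, 0)  # views into the storage of `ones`
k = torch.cat(chunks)             # new tensor with its own copied storage

# Writing through the base tensor is visible in the chunk views...
ones[1] = 5.0
print(chunks[1])  # tensor([5.])

# ...but k was copied before that write, so it still holds the old values.
print(k)          # tensor([1., 1., 1.])

# Writing into k does not touch `ones` (or the chunks) either.
k[0] = 2.0
print(ones)       # tensor([1., 5., 1.])

# The first chunk starts at the same address as `ones`; k lives elsewhere.
print(chunks[0].data_ptr() == ones.data_ptr())  # True
print(k.data_ptr() == ones.data_ptr())          # False
```

In other words, chunk followed by cat round-trips the values but not the memory, so later edits to one side never reach the other.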


Thanks for the reply, @richard.
I am trying to understand the difference better.

x = torch.cat([x1, x2, x3])
out = f(x)

Would the gradients propagate to x1, x2, and x3 as expected, or would they not receive any gradients?

Also is there a good way of merging all the chunks together, so that I can perform operations on them as if they are a single tensor?

Thank you.

The gradients will propagate as long as x1, x2, ... require grad.
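A minimal sketch of this: although torch.cat copies the data, the copy is recorded in the autograd graph, so backward routes each slice of the gradient to the matching input. Here f(x) is replaced by a hypothetical weighted sum, and x2 is deliberately left frozen to show that an input without requires_grad simply gets no gradient.

```python
import torch

# Hypothetical leaf tensors; x2 is frozen on purpose.
x1 = torch.ones(1, requires_grad=True)
x2 = torch.ones(1, requires_grad=False)
x3 = torch.ones(1, requires_grad=True)

x = torch.cat([x1, x2, x3])  # new storage, but tracked by autograd
weights = torch.tensor([1.0, 2.0, 3.0])
out = (x * weights).sum()    # stand-in for f(x)
out.backward()

# d(out)/dx = weights, split back to each input by cat's backward.
print(x1.grad)  # tensor([1.])
print(x2.grad)  # None, because x2 does not require grad
print(x3.grad)  # tensor([3.])
```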

Also is there a good way of merging all the chunks together, so that I can perform operations on them as if they are a single tensor?

You could use x = torch.cat([x1, x2, x3]), apply some in-place operations, and then split it back up via y1, y2, y3 = x.chunk(3).
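A sketch of that round trip with hypothetical values. Note that y1, y2, y3 are views of the merged tensor x, not of the original x1, x2, x3, so the originals only change if the result is copied back explicitly. The copy-back loop at the end is an assumption about what you may want, not part of the suggestion above; torch.no_grad() keeps it out of the autograd graph, which matters if x1..x3 are leaf tensors with requires_grad=True.

```python
import torch

# Hypothetical chunks, e.g. the tensors handed to an optimizer.
x1 = torch.tensor([1.0])
x2 = torch.tensor([2.0])
x3 = torch.tensor([3.0])

x = torch.cat([x1, x2, x3])  # one contiguous copy: tensor([1., 2., 3.])
x.mul_(10)                   # in-place ops act on the merged copy
y1, y2, y3 = x.chunk(3)      # views of x, not of x1..x3

print(y1, y2, y3)  # tensor([10.]) tensor([20.]) tensor([30.])
print(x1)          # tensor([1.]) (the originals are untouched)

# Optional: write the result back into the original chunks.
with torch.no_grad():
    for src, dst in zip((y1, y2, y3), (x1, x2, x3)):
        dst.copy_(src)

print(x1, x2, x3)  # tensor([10.]) tensor([20.]) tensor([30.])
```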
