Chunk + cat data assignment does not propagate

The gradients will propagate as long as x1, x2, ... require grad.
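A minimal sketch of that claim. It assumes the pieces are merged with `torch.cat` and split again with `chunk`, as discussed further down the thread. Both operations are differentiable, so a loss computed from the chunks sends gradients back to each original leaf tensor. The weights 1, 2 and 3 are arbitrary illustrative choices that make each gradient easy to identify.

```python
import torch

# Leaf tensors that require grad
x1 = torch.randn(2, 3, requires_grad=True)
x2 = torch.randn(2, 3, requires_grad=True)
x3 = torch.randn(2, 3, requires_grad=True)

# Merge along dim 0, then split back into three (2, 3) chunks
x = torch.cat([x1, x2, x3])
y1, y2, y3 = x.chunk(3)

# Weight each chunk differently so each gradient is recognizable
loss = (1 * y1 + 2 * y2 + 3 * y3).sum()
loss.backward()

print(x1.grad)  # all ones
print(x2.grad)  # all twos
print(x3.grad)  # all threes
```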

Also, is there a good way of merging all the chunks together, so that I can perform operations on them as if they are a single tensor?

You could use x = torch.cat([x1, x2, x3]), apply some in-place operations, and then split it back up via y1, y2, y3 = x.chunk(3).
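A minimal sketch of that workflow, using small constant tensors for illustration. `torch.cat` allocates a new tensor and copies the data into it. The in-place operation therefore changes `x` and the chunks taken from it, but not the original `x1`, `x2` and `x3`. The chunks returned by `x.chunk(3)` are views into `x`, so a later write to `x` is visible through them.

```python
import torch

x1 = torch.ones(2, 3)
x2 = torch.full((2, 3), 2.0)
x3 = torch.full((2, 3), 3.0)

# torch.cat copies the data into a new (6, 3) tensor; it is not a view of x1, x2, x3
x = torch.cat([x1, x2, x3])

# In-place operation on the merged tensor
x.mul_(10)

# chunk splits along dim 0 by default and returns views into x
y1, y2, y3 = x.chunk(3)

print(y1[0, 0].item(), y3[0, 0].item())  # 10.0 30.0
print(x1[0, 0].item())  # 1.0: the original tensor is unchanged

# Because y1 is a view of x, writing to x shows up in y1
x[0, 0] = -1.0
print(y1[0, 0].item())  # -1.0
```

If the original tensors themselves must see the update, the results have to be copied back explicitly, for example with `x1.copy_(y1)`.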
