The gradients will propagate as long as `x1`, `x2`, ... have `requires_grad=True`.
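
A minimal sketch of this, assuming the chunks come from `torch.chunk` on a tensor that requires grad (the shapes and weights are just illustrative):

```python
import torch

# Chunking a tensor that requires grad keeps each chunk
# connected to the autograd graph, so gradients flow back.
x = torch.arange(6., requires_grad=True)
x1, x2, x3 = x.chunk(3)            # three views into x
loss = (x1 * 1 + x2 * 2 + x3 * 3).sum()
loss.backward()
print(x.grad)  # tensor([1., 1., 2., 2., 3., 3.])
```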
Also, is there a good way of merging all the chunks together, so that I can perform operations on them as if they were a single tensor?
You could use `x = torch.cat([x1, x2, x3])`, apply some in-place operations, and then split it back up via `y1, y2, y3 = x.chunk(3)`.
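
A minimal sketch of that round trip, assuming three equally sized 1-D chunks (the shapes are illustrative):

```python
import torch

x1 = torch.randn(2, requires_grad=True)
x2 = torch.randn(2, requires_grad=True)
x3 = torch.randn(2, requires_grad=True)

x = torch.cat([x1, x2, x3])  # merge the chunks into one tensor
x.mul_(2)                    # in-place op on the merged tensor
y1, y2, y3 = x.chunk(3)      # split back into three views

(y1 + y2 + y3).sum().backward()
print(x1.grad)  # tensor([2., 2.]) -- gradients reach the original chunks
```

Note that the in-place `mul_` is safe here because `cat`'s backward does not need the saved input values; some other ops would raise an autograd error if modified in place.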