Thanks for the reply, @richard.
I am trying to understand the difference better. Suppose I have:
x = torch.cat([x1, x2, x3])
out = f(x)
Would the gradients propagate to x1, x2, and x3 as expected, or would they not get any gradients?
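One way to check this empirically is a small sketch like the one below. It assumes x1, x2, and x3 are leaf tensors created with requires_grad=True; the shapes and the body of f are made up for illustration and stand in for whatever the real function is.

```python
import torch

# Hypothetical leaf tensors standing in for x1, x2, x3 (shapes are arbitrary).
x1 = torch.randn(2, 3, requires_grad=True)
x2 = torch.randn(4, 3, requires_grad=True)
x3 = torch.randn(1, 3, requires_grad=True)

def f(x):
    # Placeholder for the real function; any differentiable op would do.
    return (x ** 2).sum()

x = torch.cat([x1, x2, x3])  # concatenates along dim 0, giving shape (7, 3)
out = f(x)
out.backward()

# If gradients flow through torch.cat, each input gets its own slice of the
# gradient. For this particular f, d(out)/d(xi) is 2 * xi.
grads_match = all(
    torch.allclose(t.grad, 2 * t.detach()) for t in (x1, x2, x3)
)
print(x.shape, grads_match)
```

If `.grad` on each of the three inputs is populated after `backward()`, the gradients are reaching them through the concatenation.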
Also, is there a good way to merge all the chunks together so that I can operate on them as if they were a single tensor?
Thank you.