# If I don't use `torch.chunk` will I have backprop errors?

I currently have the following tensor:

```
torch.Size([2, 2, 128, 1, 1])
```

Normally the batch dimension comes first, but here I’ve split the tensor along the channel dimension, i.e.,

```
(2, 512) => (2, 2, 128, 1, 1)
```
• 1st dimension is the “split” dimension
• 2nd dimension is batch dimension
• 3rd dimension is # of channels
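A split-first layout like this (split, then batch, then channels) can be built from a per-sample output of 2·C features with `view` followed by `permute`. This is a minimal sketch under assumed sizes (N=2, C=128); the post doesn't say how its tensor was actually produced.

```python
import torch

# Hypothetical sizes: batch N=2, C=128 channels per half (not taken from the post)
N, C = 2, 128
flat = torch.randn(N, 2 * C)  # stand-in for a layer output with 2*C features per sample

# (N, 2*C) -> (N, 2, C, 1, 1) -> (2, N, C, 1, 1): split dimension first, batch second
split_first = flat.view(N, 2, C, 1, 1).permute(1, 0, 2, 3, 4)
print(split_first.shape)  # torch.Size([2, 2, 128, 1, 1])

# The first slice holds the first C features of every sample
assert torch.equal(split_first[0].reshape(N, C), flat[:, :C])
```

`permute` returns a non-contiguous view, which is fine for the elementwise scale/shift that follows.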

Right now, I do the following:

```
# emb_layers returns a (2, N, C, 1, 1) tensor; unpacking splits it along dim 0
cond_w, cond_b = self.emb_layers(emb)
assert cond_w.shape == (N, self.out_channels, 1, 1)
x = self.out_norm(x) * (1 + cond_w) + cond_b
```

so I just unpack the returned tensor along its first dimension into 2 child tensors.
I’m wondering whether this will affect backprop, since I never see people do this.
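Tuple-unpacking a tensor iterates over its first dimension (in recent PyTorch versions `Tensor.__iter__` goes through `unbind(0)`), and each resulting slice stays connected to the autograd graph. A minimal check, assuming random stand-ins for `self.emb_layers(emb)` and `self.out_norm(x)`: the gradients that reach the unpacked tensor match the hand-derived ones.

```python
import torch

torch.manual_seed(0)
N, C = 2, 128
# Leaf tensor standing in for the (2, N, C, 1, 1) output of emb_layers
scale_shift = torch.randn(2, N, C, 1, 1, requires_grad=True)
h = torch.randn(N, C, 4, 4)  # stand-in for out_norm(x); the 4x4 spatial size is arbitrary

# Unpacking iterates over dim 0, so this equals scale_shift[0], scale_shift[1]
cond_w, cond_b = scale_shift
print(cond_w.grad_fn is not None)  # True: the slices are still part of the graph

out = h * (1 + cond_w) + cond_b
out.sum().backward()

# Analytic gradients of out.sum(): w.r.t. cond_w it is h summed over H and W;
# w.r.t. cond_b it is the number of spatial positions (4 * 4 = 16)
print(torch.allclose(scale_shift.grad[0], h.sum(dim=(2, 3), keepdim=True)))  # True
print(torch.allclose(scale_shift.grad[1], torch.full((N, C, 1, 1), 16.0)))   # True
```

The only requirement is that dim 0 has size exactly 2; otherwise the unpacking raises a `ValueError`.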

The alternative is:

```
scale_shift = self.emb_layers(emb)
# dim=1 is the split dimension here, i.e. this assumes a (N, 2, C, 1, 1) layout
cond_w, cond_b = th.chunk(scale_shift, 2, dim=1)
cond_w = cond_w.squeeze(1)
cond_b = cond_b.squeeze(1)
assert cond_w.shape == cond_b.shape
assert cond_w.shape == (N, self.out_channels, 1, 1)
```

Is one method preferable to the alternative?
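Both versions are differentiable; they differ only in which dimension holds the split. A minimal comparison, assuming random stand-ins and the same values stored once split-first `(2, N, C, 1, 1)` and once batch-first `(N, 2, C, 1, 1)`, which is the layout the `dim=1` chunk implies. The helper `modulate` is hypothetical and just reproduces the scale/shift from the question.

```python
import torch

torch.manual_seed(0)
N, C = 2, 128
h = torch.randn(N, C, 4, 4)  # stand-in for out_norm(x)

def modulate(cond_w, cond_b):
    # hypothetical helper: same scale/shift as in the question
    return h * (1 + cond_w) + cond_b

# Version 1: split-first layout (2, N, C, 1, 1), unpacked along dim 0
a = torch.randn(2, N, C, 1, 1, requires_grad=True)
cond_w, cond_b = a
out_a = modulate(cond_w, cond_b)
out_a.pow(2).sum().backward()

# Version 2: the same values in batch-first layout (N, 2, C, 1, 1), chunked along dim 1
b = a.detach().transpose(0, 1).clone().requires_grad_(True)
cond_w2, cond_b2 = torch.chunk(b, 2, dim=1)
cond_w2, cond_b2 = cond_w2.squeeze(1), cond_b2.squeeze(1)
out_b = modulate(cond_w2, cond_b2)
out_b.pow(2).sum().backward()

# Same forward output and same gradients (after undoing the transpose)
print(torch.allclose(out_a, out_b))                       # True
print(torch.allclose(a.grad, b.grad.transpose(0, 1)))     # True
```

For a split-first tensor, `torch.chunk(scale_shift, 2, dim=0)` followed by `squeeze(0)`, or `scale_shift.unbind(0)`, gives the same two tensors as plain unpacking.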

`torch.chunk` does not detach the resulting tensors from the computation graph, so gradients flow back through each chunk and match the full-batch result, as seen here:

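```python
import torch
from torchvision import models

x = torch.randn(10, 3, 224, 224, requires_grad=True)
# eval() makes BatchNorm use its running stats, so each sample's output
# does not depend on which other samples share its batch or chunk
model = models.resnet18().eval()

# full batch
out = model(x)
out.mean().backward()
grad_full = x.grad.clone()

# chunk
x.grad = None  # clear the gradient accumulated by the full-batch pass
x1, x2 = torch.chunk(x, 2, 0)

out1 = model(x1)
out2 = model(x2)
out12 = torch.cat((out1, out2), 0)
out12.mean().backward()

# gradients flow back through both chunks into x; the difference
# should be at the level of float32 rounding
print((grad_full - x.grad).abs().max())
```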