If I don't use `torch.chunk`, will I have backprop errors?

I currently have the following tensor:

torch.Size([2, 2, 128, 1, 1])

Normally the batch dimension comes first, but here I’ve split the tensor along the channel dimension, i.e.,

(2, 512) => (2, 2, 128, 1, 1)
  • 1st dimension is the “split” dimension
  • 2nd dimension is batch dimension
  • 3rd dimension is # of channels
  • Last 2 dimensions are singleton dimensions I added for broadcasting
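A minimal sketch of building that layout, assuming the embedding layer produces a flat `(N, 2*C)` tensor with the scale in the first `C` channels and the shift in the last `C` (the names `N`, `C`, `flat` and `split` are illustrative, not from the original code). A plain `view(2, N, C)` would mix entries from different batch rows, so the split dimension is created next to the channels and then moved to the front:

```python
import torch

N, C = 2, 128  # illustrative batch size and per-half channel count

# Hypothetical embedding output: scale in channels [0, C), shift in [C, 2C)
flat = torch.randn(N, 2 * C, requires_grad=True)

# Split the channel dim into (2, C), move the split dim to the front,
# then add two trailing singleton dims for broadcasting over H and W.
split = flat.view(N, 2, C).transpose(0, 1)[..., None, None]
print(split.shape)  # torch.Size([2, 2, 128, 1, 1])

# Each slice along dim 0 lines up with the corresponding channel block
assert torch.equal(split[0, :, :, 0, 0], flat[:, :C])  # scale
assert torch.equal(split[1, :, :, 0, 0], flat[:, C:])  # shift
```

Since `view`, `transpose` and indexing are all differentiable, `split` stays connected to `flat` in the autograd graph.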

Right now, I do the following:

        # Unpacking iterates over dim 0, i.e. the "split" dimension
        cond_w, cond_b = self.emb_layers(emb)
        assert cond_w.shape == (N, self.out_channels, 1, 1)
        x = self.out_norm(x) * (1 + cond_w) + cond_b

so I just unpack the returned tensor along its first dimension into 2 child tensors.
I’m wondering whether this will affect backprop, as I never see people do this.
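One way to check this directly: tuple-unpacking a tensor iterates over its first dimension (it behaves like `tensor.unbind(0)`), and each resulting tensor keeps a `grad_fn`, so it stays attached to the graph. A minimal sketch with random stand-ins for `self.emb_layers(emb)` and `self.out_norm(x)` (the variables below are hypothetical):

```python
import torch

N, C = 2, 128

# Stand-ins for self.emb_layers(emb) and self.out_norm(x)
scale_shift = torch.randn(2, N, C, 1, 1, requires_grad=True)
h = torch.randn(N, C, 8, 8)

# Tuple-unpacking iterates over dim 0, like scale_shift.unbind(0)
cond_w, cond_b = scale_shift
assert cond_w.shape == (N, C, 1, 1)
assert cond_w.grad_fn is not None  # still part of the autograd graph

out = h * (1 + cond_w) + cond_b
out.sum().backward()

# d(out.sum())/d(cond_b) is the number of broadcast positions (8 * 8)
print(scale_shift.grad[1].unique())  # tensor([64.])
```

The gradient reaches the original `scale_shift` tensor, which is what matters for training `emb_layers`.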

The alternative is:

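        scale_shift = self.emb_layers(emb)
        # dim 0 is the "split" dimension in the (2, N, C, 1, 1) layout;
        # chunking on dim=1 would split the batch instead
        cond_w, cond_b = th.chunk(scale_shift, 2, dim=0)
        cond_w = cond_w.squeeze(0)
        cond_b = cond_b.squeeze(0)
        assert cond_w.shape == cond_b.shape
        assert cond_w.shape == (N, self.out_channels, 1, 1)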

Is one method preferable to the other?
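To compare the two variants, one can feed the same tensor through both and check that the forward values and the gradients match exactly. This sketch assumes the `(2, N, C, 1, 1)` layout described above and therefore chunks along dim 0, the split dimension; the tensors and helper functions (`grads`, `unpack`, `chunk`) are hypothetical stand-ins:

```python
import torch

N, C = 2, 128

def grads(split_fn):
    # Same seed and a fresh leaf each run, so both variants see identical inputs
    torch.manual_seed(0)
    scale_shift = torch.randn(2, N, C, 1, 1, requires_grad=True)
    h = torch.randn(N, C, 8, 8)  # stand-in for self.out_norm(x)
    cond_w, cond_b = split_fn(scale_shift)
    out = h * (1 + cond_w) + cond_b
    out.sum().backward()
    return out.detach(), scale_shift.grad

def unpack(t):
    # Variant 1: tuple-unpack along dim 0
    cond_w, cond_b = t
    return cond_w, cond_b

def chunk(t):
    # Variant 2: chunk along the split dim (dim 0), then drop it
    cond_w, cond_b = torch.chunk(t, 2, dim=0)
    return cond_w.squeeze(0), cond_b.squeeze(0)

out_a, grad_a = grads(unpack)
out_b, grad_b = grads(chunk)
print(torch.equal(out_a, out_b), torch.equal(grad_a, grad_b))  # True True
```

Both variants only select slices of the same tensor, so neither is preferable for backprop; the choice comes down to readability.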

`torch.chunk` does not detach the resulting tensors from the computation graph, and it yields the same gradients as processing the whole tensor at once, as seen here:

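import torch
from torchvision import models

x = torch.randn(10, 3, 224, 224, requires_grad=True)
# eval() so BatchNorm uses running stats instead of per-chunk batch stats
model = models.resnet18().eval()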

# full batch
x.grad = None
out = model(x)
out.mean().backward()
print(x.grad.abs().sum())
# tensor(1.1471)

# chunk
x.grad = None
x1, x2 = torch.chunk(x, 2, 0)

out1 = model(x1)
out2 = model(x2)
out12 = torch.cat((out1, out2), 0)
out12.mean().backward()

print(x.grad.abs().sum())
# tensor(1.1471)

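(assuming all other operations behave identically on the chunks, e.g. no layers such as BatchNorm in training mode whose output depends on the batch statistics).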
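The caveat matters mostly for layers whose output depends on the whole batch. A small sketch with `nn.BatchNorm1d`, chosen here only as an example of such a layer: in eval mode it uses its running statistics, so chunked and full-batch outputs agree; in training mode it normalizes each chunk with that chunk's own statistics, so the outputs generally differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(10, 4)
bn = nn.BatchNorm1d(4)

def full_vs_chunked(layer):
    # Compare one full-batch call against two half-batch calls
    full = layer(x)
    chunked = torch.cat([layer(c) for c in torch.chunk(x, 2, 0)], 0)
    return torch.allclose(full, chunked)

bn.eval()   # uses running_mean / running_var: each sample handled independently
eval_match = full_vs_chunked(bn)
print(eval_match)  # True

bn.train()  # uses the mean / var of whatever batch it is called on
train_match = full_vs_chunked(bn)
print(train_match)  # False
```

This is why the resnet18 example above calls `.eval()` before comparing the full-batch and chunked gradients.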
