I currently have the following tensor:

```
torch.Size([2, 2, 128, 1, 1])
```

Normally the batch dimension is at the start, but here I’ve split the tensor along the channel dimension. I.e.,

```
(2, 256) => (2, 2, 128, 1, 1)
```

- 1st dimension is the “split” dimension
- 2nd dimension is batch dimension
- 3rd dimension is # of channels
- Last 2 dimensions are singleton dimensions I added for broadcasting
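For reference, here is a sketch of how such a layout can be produced (shapes and names are my assumptions: batch `N = 2`, channels `C = 128`, so the pre-split tensor has `2 × 128 = 256` channels per batch element):

```python
import torch

# Assumed shapes: batch N = 2, channels C = 128,
# pre-split tensor has 2 * C = 256 channels per batch element.
N, C = 2, 128
flat = torch.randn(N, 2 * C)

# (batch, 2 * C) -> (batch, split, C) -> (split, batch, C),
# then two trailing singleton dims for broadcasting
split = flat.view(N, 2, C).permute(1, 0, 2)[..., None, None]
print(split.shape)  # torch.Size([2, 2, 128, 1, 1])
```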

Right now, I do the following:

```
cond_w, cond_b = self.emb_layers(emb)  # unpacks along dim 0: out[0], out[1]
assert cond_w.shape == (N, self.out_channels, 1, 1)
x = self.out_norm(x) * (1 + cond_w) + cond_b
```

so I just unpack the returned tensor along its first dimension into 2 child tensors.

I’m wondering if this will affect backprop, as I never see people do this.
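As a quick sanity check (toy shapes, not the real module), tuple unpacking is just indexing along dim 0, which autograd treats as an ordinary view operation:

```python
import torch

# Toy check: gradients flow through tuple unpacking along dim 0
t = torch.randn(2, 3, requires_grad=True)
a, b = t                      # same as a = t[0], b = t[1]
loss = (a * 2 + b * 3).sum()
loss.backward()
print(t.grad)  # row 0 is all 2s, row 1 is all 3s
```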

The alternative is:

```
scale_shift = self.emb_layers(emb)
# split along dim 0 (the "split" dimension), then drop it
cond_w, cond_b = th.chunk(scale_shift, 2, dim=0)
cond_w = cond_w.squeeze(0)
cond_b = cond_b.squeeze(0)
assert cond_w.shape == cond_b.shape
assert cond_w.shape == (N, self.out_channels, 1, 1)
```
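For what it’s worth, a toy comparison (made-up shapes) suggests the two approaches produce identical tensors and both let gradients reach the original output:

```python
import torch

N, C = 4, 8
out = torch.randn(2, N, C, 1, 1, requires_grad=True)  # (split, batch, C, 1, 1)

# Method 1: tuple unpacking along dim 0
w1, b1 = out

# Method 2: chunk along the split dim, then drop it
w2, b2 = torch.chunk(out, 2, dim=0)
w2, b2 = w2.squeeze(0), b2.squeeze(0)

assert torch.equal(w1, w2) and torch.equal(b1, b2)

# Both are views, so backprop reaches `out` either way
(w1.sum() + b2.sum()).backward()
print(out.grad.shape)  # torch.Size([2, 4, 8, 1, 1])
```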

Is one method preferable to the other?