I currently have the following tensor:
torch.Size([2, 2, 128, 1, 1])
Normally the batch dimension is at the start, but here I’ve split the tensor on the number of channels. I.e.,
(2, 512) => (2, 2, 128, 1, 1)
- 1st dimension is the “split” dimension
- 2nd dimension is batch dimension
- 3rd dimension is # of channels
- Last 2 dimensions are 2 dimensions I added for broadcasting
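For reference, a layout like this can be built from a flat per-sample output. This is a hypothetical sketch (the assumption that the layer emits a flat `(N, 2*C)` tensor is mine, not from the code above); it just shows how the "split" axis ends up in front and where the two broadcasting dims come from:

```python
import torch

N, C = 2, 128
flat = torch.randn(N, 2 * C)             # hypothetical flat (batch, 2*channels) layer output
t = flat.view(N, 2, C).permute(1, 0, 2)  # (split, batch, channels): split axis moved to front
t = t[..., None, None]                   # two trailing singleton dims for broadcasting
# t.shape is now (2, N, C, 1, 1), i.e. (2, 2, 128, 1, 1)
```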
Right now, I do the following:
cond_w, cond_b = self.emb_layers(emb)
assert cond_w.shape == (N, self.out_channels, 1, 1)
x = self.out_norm(x) * (1 + cond_w) + cond_b
so I just unpack the returned tensor along its first dimension into two child tensors.
I’m wondering if this will affect backprop, since I never see people do this.
The alternative is:
scale_shift = self.emb_layers(emb)
cond_w, cond_b = th.chunk(scale_shift, 2, dim=0)
cond_w = cond_w.squeeze(0)
cond_b = cond_b.squeeze(0)
assert cond_w.shape == cond_b.shape
assert cond_w.shape == (N, self.out_channels, 1, 1)
Is one method preferable to the other?
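For what it's worth, given the layout described above (split dim first), the two approaches produce identical tensors, so the choice comes down to readability. A quick sketch comparing them (shapes taken from the question; note `chunk` must act on dim 0, the split dimension, not the batch dimension):

```python
import torch as th

scale_shift = th.randn(2, 2, 128, 1, 1)  # (split, batch, channels, 1, 1)

# Method 1: unpack along dim 0
w1, b1 = scale_shift

# Method 2: chunk along dim 0, then drop the singleton split dim
w2, b2 = th.chunk(scale_shift, 2, dim=0)
w2, b2 = w2.squeeze(0), b2.squeeze(0)

# Both yield (2, 128, 1, 1) tensors with identical values
```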