I’m working on a custom CTGAN-like model for tabular data, where I implemented my own _collate_fn
and run_step() logic to control sampling, noise injection, and discriminator inputs.
However, I consistently run into shape mismatch errors between the real and fake data batches during training — particularly in the discriminator step.
What I know so far:
- I already apply drop_last=True in my DataLoader.
- I try to slice all tensors in _collate_fn to a shared min_len, including disc_in_known, disc_in_unknown, disc_in_fakez, disc_in_c, disc_in_perm.
- In run_step(), I slice real/fake tensors before the loss calculation (see the sketch below).
- Still, I get errors like:
ValueError: shape mismatch: real_mean=torch.Size([6823]), fake_mean=torch.Size([6814])
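For context, here is a simplified sketch of the discriminator part of my run_step(). I've trimmed it down heavily, and helper names like self._generator and self._discriminator are paraphrased from my actual code:

def run_step(self, disc_in_known, disc_in_unknown, disc_in_fakez,
             disc_in_c, disc_in_perm):
    for step in range(self._discriminator_step):
        real = disc_in_unknown[step]
        fakez = torch.cat([disc_in_fakez[step], disc_in_c[step]], dim=1)
        fake = self._generator(fakez)

        # Slice both sides to a shared length before the loss,
        # same idea as in _collate_fn.
        min_len = min(real.shape[0], fake.shape[0])
        real, fake = real[:min_len], fake[:min_len]

        real_mean = self._discriminator(real).mean(dim=1)
        fake_mean = self._discriminator(fake).mean(dim=1)
        # Despite the slicing above, this is roughly where the
        # ValueError fires:
        loss_d = -(real_mean - fake_mean).mean()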
I’m using a custom collate_fn() because I need to sample multiple times per batch for the discriminator steps, inject Gaussian noise, and slice all inputs to the same minimum length to avoid mismatches. This function prepares both the discriminator and generator inputs before they are passed into the training loop.
from typing import List, Tuple

import torch
from torch import Tensor

def _collate_fn(self, batch: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    batch_size = len(batch)
    # Standard-normal noise parameters for the generator's latent input.
    mean = torch.zeros(batch_size, self._embedding_dim)
    std = mean + 1

    disc_in_known, disc_in_unknown, disc_in_fakez = [], [], []
    disc_in_c, disc_in_perm = [], []
    # Sample fresh noise, condition vectors, and data once per discriminator step.
    for _ in range(self._discriminator_step):
        fakez = torch.normal(mean=mean, std=std)
        c1, m1, col, opt = self._sampler.sample_condvec(batch_size)
        known, unknown = self._sampler.sample_data(batch_size, col, opt)
        perm = torch.randperm(batch_size)
        disc_in_known.append(known)
        disc_in_unknown.append(unknown)
        disc_in_fakez.append(fakez)
        disc_in_c.append(c1)
        disc_in_perm.append(perm)

    # Use min_len to align shapes across all tensors before stacking.
    min_len = min(
        t.shape[0]
        for t in disc_in_known + disc_in_unknown + disc_in_fakez + disc_in_c + disc_in_perm
    )
    disc_in_known = torch.stack([t[:min_len] for t in disc_in_known])
    disc_in_unknown = torch.stack([t[:min_len] for t in disc_in_unknown])
    disc_in_fakez = torch.stack([t[:min_len] for t in disc_in_fakez])
    disc_in_c = torch.stack([t[:min_len] for t in disc_in_c])
    disc_in_perm = torch.stack([t[:min_len] for t in disc_in_perm])

    # Generator inputs, sampled once per batch.
    gen_in_fakez = torch.normal(mean=mean, std=std)
    c1, m1, col, opt = self._sampler.sample_condvec(batch_size)
    known, unknown = self._sampler.sample_data(batch_size, col, opt)

    return (disc_in_known, disc_in_unknown, disc_in_fakez, disc_in_c,
            disc_in_perm, known, unknown, gen_in_fakez, c1, m1)
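To narrow things down, I also bolted on some shape assertions right after the torch.stack calls in _collate_fn (pure debugging code, not part of the model). If every stacked tensor really shares min_len along dim 1, that should at least rule out the collate step as the source:

# Debug-only check, placed right after the torch.stack calls above.
stacked = {
    "disc_in_known": disc_in_known,
    "disc_in_unknown": disc_in_unknown,
    "disc_in_fakez": disc_in_fakez,
    "disc_in_c": disc_in_c,
    "disc_in_perm": disc_in_perm,
}
for name, t in stacked.items():
    assert t.shape[1] == min_len, f"{name}: got {t.shape[1]}, expected {min_len}"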
Despite all the slicing and drop_last=True, I still get shape mismatch errors during training. Could there be another source of inconsistency I’m overlooking?
Any suggestions on how to ensure real and fake batches always align would be appreciated.