PyTorch GAN training: shape mismatch between real and fake tensors

I’m working on a custom CTGAN-like model for tabular data, for which I implemented my own _collate_fn and run_step() logic to control sampling, noise injection, and discriminator inputs.

However, I consistently run into shape mismatch errors between the real and fake data batches during training, particularly in the discriminator step.

What I know so far:

  • I already apply drop_last=True in my DataLoader.
  • I try to slice all tensors in _collate_fn to a shared min_len, including disc_in_known, disc_in_unknown, disc_in_fakez, disc_in_c, disc_in_perm.
  • In run_step(), I slice real/fake tensors before loss calculation.
  • Still, I get errors like:
ValueError: shape mismatch: real_mean=torch.Size([6823]), fake_mean=torch.Size([6814])
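For reference, `drop_last=True` only controls how many dataset rows the DataLoader groups into each batch before calling the collate function. Tensors drawn inside the collate function itself (for example via `self._sampler`) are not affected by it. A small sketch with a hypothetical 10-row dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset: 10 rows with 3 features each
data = TensorDataset(torch.randn(10, 3))

# The default collate returns a sequence holding one stacked tensor per field
sizes_keep = [b[0].shape[0] for b in DataLoader(data, batch_size=4)]
sizes_drop = [b[0].shape[0] for b in DataLoader(data, batch_size=4, drop_last=True)]

print(sizes_keep)  # [4, 4, 2]: the last batch is partial
print(sizes_drop)  # [4, 4]: the partial batch is discarded
```

So `drop_last=True` guarantees equal-sized DataLoader batches, but it says nothing about the shapes of extra tensors sampled per step.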

I’m using a custom _collate_fn() because I need to sample multiple times per batch for the discriminator, inject Gaussian noise, and slice all inputs to the same minimum length to avoid mismatches. This function prepares both the discriminator and generator inputs before they are passed into the training loop.

```python
from typing import List, Tuple

import torch
from torch import Tensor


# Method of the model class: self._embedding_dim, self._discriminator_step and
# self._sampler are attributes defined elsewhere in that class.
def _collate_fn(self, batch: List[Tuple[Tensor, ...]]) -> Tuple[Tensor, ...]:
    batch_size = len(batch)
    # Standard-normal noise parameters for the generator input
    mean = torch.zeros(batch_size, self._embedding_dim)
    std = mean + 1

    disc_in_known, disc_in_unknown, disc_in_fakez = [], [], []
    disc_in_c, disc_in_perm = [], []

    # Draw one set of discriminator inputs per discriminator step
    for _ in range(self._discriminator_step):
        fakez = torch.normal(mean=mean, std=std)
        c1, m1, col, opt = self._sampler.sample_condvec(batch_size)
        known, unknown = self._sampler.sample_data(batch_size, col, opt)
        perm = torch.randperm(batch_size)

        disc_in_known.append(known)
        disc_in_unknown.append(unknown)
        disc_in_fakez.append(fakez)
        disc_in_c.append(c1)
        disc_in_perm.append(perm)

    # Use min_len to align shapes across all tensors (slices dim 0 only)
    min_len = min(k.shape[0] for k in disc_in_known + disc_in_fakez + disc_in_c + disc_in_perm + disc_in_unknown)
    disc_in_known = torch.stack([k[:min_len] for k in disc_in_known])
    disc_in_unknown = torch.stack([k[:min_len] for k in disc_in_unknown])
    disc_in_fakez = torch.stack([k[:min_len] for k in disc_in_fakez])
    disc_in_c = torch.stack([k[:min_len] for k in disc_in_c])
    disc_in_perm = torch.stack([k[:min_len] for k in disc_in_perm])

    # Generator input
    gen_in_fakez = torch.normal(mean=mean, std=std)
    c1, m1, col, opt = self._sampler.sample_condvec(batch_size)
    known, unknown = self._sampler.sample_data(batch_size, col, opt)

    return (disc_in_known, disc_in_unknown, disc_in_fakez, disc_in_c, disc_in_perm,
            known, unknown, gen_in_fakez, c1, m1)
```
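One way to narrow this down is to assert the shapes where they are produced. Below is a minimal sketch of a hypothetical helper, `check_leading_dims`, that could be called on the stacked `disc_in_*` tensors just before `_collate_fn` returns. The tensor sizes are placeholders, not values from the original model. It checks only the leading (step, row) dimensions, which are the ones the `min_len` slicing touches.

```python
import torch
from torch import Tensor
from typing import Dict, Tuple


def check_leading_dims(tensors: Dict[str, Tensor], ndims: int = 2) -> Tuple[int, ...]:
    """Raise if the named tensors disagree on their first `ndims` dimensions.

    Returns the shared leading shape, e.g. (discriminator_steps, rows).
    """
    shapes = {name: tuple(t.shape[:ndims]) for name, t in tensors.items()}
    if len(set(shapes.values())) != 1:
        raise ValueError(f"leading-dim mismatch: {shapes}")
    return next(iter(shapes.values()))


# Stand-in tensors shaped like the stacked disc_in_* outputs: (steps, rows, ...)
steps, rows = 3, 8
aligned = {
    "disc_in_known": torch.zeros(steps, rows, 5),
    "disc_in_fakez": torch.zeros(steps, rows, 128),
    "disc_in_perm": torch.zeros(steps, rows, dtype=torch.long),
}
print(check_leading_dims(aligned))  # (3, 8)

# One tensor with an extra row should be reported by name
misaligned = dict(aligned, disc_in_fakez=torch.zeros(steps, rows + 1, 128))
try:
    check_leading_dims(misaligned)
    error_message = None
except ValueError as err:
    error_message = str(err)
    print(error_message)
```

Note that in the function above, the generator-side tensors (`gen_in_fakez`, `c1`, `known`, `unknown`) are returned without any `min_len` slicing, so they would need their own check.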

Despite all the slicing and drop_last=True, I still get shape mismatch errors during training. Could there be another source of inconsistency I’m overlooking?

Any suggestions on how to ensure real and fake batches always align would be appreciated.

Could you describe where the shape mismatch error is raised, and whether you’ve already checked the batch shapes before that point? Also, do the printed shapes look as expected, i.e. do 6823 and 6814 correspond to the batch size or to some other dimension you’ve specified?
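To see which dimension 6823 and 6814 refer to, it may help to compare full shapes rather than only the means. Here is a hedged sketch with made-up sizes, using a hypothetical helper `mismatched_dims`. If the two numbers turn out to be feature widths rather than row counts, slicing rows to `min_len` cannot make them equal, because it only truncates dimension 0.

```python
import torch
from torch import Tensor
from typing import List


def mismatched_dims(real: Tensor, fake: Tensor) -> List[int]:
    """Return the indices of the dimensions where `real` and `fake` differ."""
    if real.dim() != fake.dim():
        raise ValueError(f"rank mismatch: {tuple(real.shape)} vs {tuple(fake.shape)}")
    return [d for d in range(real.dim()) if real.shape[d] != fake.shape[d]]


# Hypothetical sizes: (rows, features), differing in BOTH rows and features
real = torch.randn(8, 12)
fake = torch.randn(10, 9)

# Truncating dim 0 (as min_len does) fixes the row count only
n = min(real.shape[0], fake.shape[0])
real, fake = real[:n], fake[:n]

print(mismatched_dims(real, fake))                     # [1]
print(real.mean(dim=0).shape, fake.mean(dim=0).shape)  # torch.Size([12]) torch.Size([9])
```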

Not sure if it’s just confusing variable naming, but batch_size = len(batch) is the number of items (read: tensors) in each batch, whereas the actual batch size is batch_size = batch[0].shape[0].
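What `_collate_fn` actually receives as `batch` depends on how the DataLoader is configured, and the original post does not show that configuration. Here is a minimal sketch with hypothetical tensors `x` and `y` that compares two common setups:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

x, y = torch.randn(8, 5), torch.randn(8, 3)


def identity(batch):
    # Return whatever the DataLoader hands to collate_fn, unchanged
    return batch


# Case 1: automatic batching (batch_size set): collate_fn gets a list of per-row samples
per_row = next(iter(DataLoader(TensorDataset(x, y), batch_size=8, collate_fn=identity)))
print(len(per_row))  # 8: number of rows

# Case 2: the dataset already yields whole batches and batch_size=None:
# collate_fn gets one sample, here a tuple of two batched tensors
pre_batched = next(iter(DataLoader([(x, y)], batch_size=None, collate_fn=identity)))
print(len(pre_batched))         # 2: number of tensors
print(pre_batched[0].shape[0])  # 8: number of rows
```

In the first setup `len(batch)` is the row count. In the second it is the number of tensors, and `batch[0].shape[0]` gives the row count instead.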