How to enforce symmetry on generator output?

I’m having a lot of trouble implementing the paper Generative Modeling for Protein Structures. I’m using the PyTorch DCGAN tutorial as a reference, since the two architectures are similar.

Context:

The dataloader in the tutorial is implemented as:

dataset = dset.ImageFolder(...)
for i, data in enumerate(dataloader, 0):
    real_cpu = data[0].to(device)

and gives shapes:

print(data[0].shape)
print(data[1].shape)
torch.Size([128, 3, 64, 64])
torch.Size([128])

Whereas my dataset is implemented as:

with h5py.File('/home/collin/protein_maps/dataset.hdf5', 'r') as f:
    x = f['train_64'][:]
dataloader = torch.utils.data.DataLoader(x, ...)

for i, data in enumerate(dataloader, 0):
    # Unsqueeze dim 1 to convert [128, 64, 64] to [128, 1, 64, 64] to conform to D's architecture
    real_cpu = (data.unsqueeze(dim=1).type(torch.FloatTensor)).to(device)

This causes problems later on, when I want to enforce symmetry while passing the output from the generator to the discriminator, as specified in the original paper:

During training, we enforce that G(z) be positive by clamping output values above zero and symmetric by setting G(z) = (G(z)+G(z).T)/2 before passing the generated map to the discriminator.
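
A direct translation of that formula onto the batched generator output (roughly what I tried; fake stands in for G(z)) is:

fake = netG(noise)         # shape [128, 1, 64, 64]
sym = (fake + fake.T) / 2  # fake.T has shape [64, 64, 1, 128]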

but I get a broadcasting error between [128, 1, 64, 64] and [64, 64, 1, 128], which makes sense: on a 4-D tensor, .T reverses all dimensions rather than transposing just the last two. The authors specify that this step be done inside the generator architecture; quoting verbatim:

  Model architectures. Each layer is presented as:
   
  Layer(filters, kernel size, stride, padding)

  ------------------64 GAN------------------
  down-scale factor = 100
  
  --Generator--
  nz = 100
  ConvTranspose2d(512, 4, 1, 0)
  BatchNorm2d(512)
  LeakyReLU(0.2)
  ConvTranspose2d(256, 4, 2, 1)
  BatchNorm2d(256)
  LeakyReLU(0.2)
  ConvTranspose2d(128, 4, 2, 1)
  BatchNorm2d(128)
  LeakyReLU(0.2)
  ConvTranspose2d(64, 4, 2, 1)
  BatchNorm2d(64)
  LeakyReLU(0.2)
  ConvTranspose2d(1, 4, 2, 1)
  Clamp(>0)
  Enforce Symmetry

but I’m not sure how to do this in practice, as my implementation of that architecture looks like this:

self.main = nn.Sequential(
    nn.ConvTranspose2d(nz, 512, kernel_size=4, stride=1, padding=0),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2),
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2),
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
)

Question

How can I clamp and enforce symmetry in this part of my architecture? Is this possible, or do I have to do it during training? If I do it during training, what’s the best way to slice out just the [64, 64] maps from the [128, 1, 64, 64] batch and reinsert them afterwards? Is there a function in PyTorch that enforces symmetry inside an nn.Sequential?

I’m not familiar with the paper, but wouldn’t this work:

ret = (real_cpu + real_cpu.permute(0, 1, 3, 2)) / 2
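
permute(0, 1, 3, 2) swaps only the last two dimensions, i.e. it transposes each [64, 64] map while leaving the batch and channel dimensions alone, so the shapes broadcast cleanly. If you want the clamp and the symmetrization to live inside the generator itself, you could also wrap them in a small module and append it to your nn.Sequential. A minimal sketch (this Symmetrize class is my own helper, not something built into torch.nn):

import torch.nn as nn

class Symmetrize(nn.Module):
    # Clamp to non-negative values, then average each map with its
    # transpose over the last two dimensions to make it symmetric.
    def forward(self, x):
        x = x.clamp(min=0)
        return (x + x.transpose(-2, -1)) / 2

Then Symmetrize() just goes after the final nn.ConvTranspose2d(64, 1, ...) in your generator’s nn.Sequential, and netG(noise) comes out clamped and symmetric.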

Thank you for your reply. This worked, but now it seems like my generator isn’t updating.

Here is my training loop:

# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        # Unsqueeze dim 1 to convert [128, 64, 64] to [128, 1, 64, 64] to conform to D's architecture
        real_cpu = (data.unsqueeze(dim=1).type(torch.FloatTensor)).to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Clamp to non-negative and symmetrize the detached fake before showing it to D
        fake_det = fake.detach().clamp(min=0)
        sym_fake = (fake_det + fake_det.permute(0, 1, 3, 2)) / 2
        # Classify all fake batch with D
        output = netD(sym_fake).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        #adjust_optim(optimizerD, iters)
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake.detach()).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        adjust_optim(optimizerG, iters)

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

I used your function on the generator’s output when feeding it to the discriminator, but the generator fails to learn after this.
