Transfoermer based GAN not training properly

Hello there, I’m working on transformer based GAN and i have trouble training the model. I know that GANs are hard to train and choosing proper hyperparameters can be tricky but I think there might be something wrong with my model. I’m quite new to torch so I don’t see where the problem might be. During the training, discriminator’s loss after few examples goes to 0 while generator’s keeps rising. Even when simplifying disc model to only two embeddings and sequential layers I cannot get generator to produce meaningful data. I think that code itself is quite straightforward but if you have some questions please ask so we can find problems.

import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from random import randint
from torch.utils.data import Dataset, DataLoader

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

DEVICE = "cuda"
N_LAYERS = 2
SEQ_SIZE = 48
BATCH_SIZE = 32
N_HEAD = 6
HIDDEN_SIZE = 128
LR = 3e-4

# some data avialable here https://pastebin.com/LLa9baAQ
with open("./jsbdataset.txt", "r") as file:
    data = file.read().strip().split(" ")

DATA_LENGHT = len(data)
VOCAB_SIZE = len(set(data))

EMBED_DIM = 360


class JBChorales(Dataset):
    def __init__(self):
        self.raw_data = [int(num) - 35 if num == 0 else 0 for num in data]
        self.data = []
        for _ in range(DATA_LENGHT // SEQ_SIZE):
            rand_int = randint(0, DATA_LENGHT - SEQ_SIZE)
            remainder = rand_int % 4
            rand_div_nr = rand_int - remainder
            self.data.append(self.raw_data[rand_div_nr : rand_div_nr + SEQ_SIZE])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx])


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.gen_embed = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.pos_embed = nn.Embedding(SEQ_SIZE, EMBED_DIM)
        self.enc_layer = nn.TransformerEncoderLayer(EMBED_DIM, N_HEAD, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.enc_layer, N_LAYERS)
        self.dense = nn.Linear(EMBED_DIM, VOCAB_SIZE)

    def forward(self, x: torch.Tensor):
        cur_seq_size = x.size()[1]
        if cur_seq_size > SEQ_SIZE:
            x = x[:, :-SEQ_SIZE, :]
        mask = nn.Transformer.generate_square_subsequent_mask(cur_seq_size)
        x = self.gen_embed(x)  # -> BATCH_SIZE x SEQ_SIZE x EMBED_DIM
        pos = self.pos_embed(torch.arange(cur_seq_size, device=DEVICE))
        x = x + pos
        x = self.encoder.forward(x, mask=mask, is_causal=True)
        x = self.dense(x)  # -> BATCH_SIZE x SEQ_SIZE x VOCAB_SIZE
        x = x[:, -1, :]  # -> BATCH_SIZE x 1 x VOCAB_SIZE
        x = F.softmax(x, dim=-1)
        return x.view(BATCH_SIZE, VOCAB_SIZE)  # -> BATCH_SIZE x VOCAB_SIZE

    def generate(self, previous: torch.Tensor, n_iter: int) -> torch.Tensor:
        for i in range(n_iter - 1):
            probs = self.forward(previous)
            chosen = torch.argmax(probs, dim=1, keepdim=True)
            previous = torch.cat([previous, chosen], dim=1)
        return previous


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc_embed = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.pos_embed = nn.Embedding(SEQ_SIZE, EMBED_DIM)
        self.enc_layer = nn.TransformerEncoderLayer(EMBED_DIM, N_HEAD, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.enc_layer, N_LAYERS)
        self.classifier = nn.Sequential(nn.Linear(EMBED_DIM, 1), nn.Sigmoid())

    def forward(self, x):
        mask = nn.Transformer.generate_square_subsequent_mask(SEQ_SIZE)
        x = self.disc_embed(x)  # -> BATCH_SIZE x OUT_SIZE x EMBED_DIM
        pos = self.pos_embed(torch.arange(x.size()[1], device=DEVICE))
        x = x + pos
        x = self.encoder.forward(
            x, mask=mask, is_causal=True
        )  # -> BATCH_SIZE x OUT_SIZE x EMBED_DIM
        x = x.mean(dim=1)  # -> BATCH_SIZE x EMBED_DIM
        x = self.classifier(x)  # -> BATCH_SIZE x 1
        return x


# TRAINING LOOP

# DATA
dataset = JBChorales()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# MODELS
gen = Generator().to(DEVICE)
disc = Discriminator().to(DEVICE)

# SOME OTHER STUFF
criterion = nn.BCELoss()
optim_disc = torch.optim.Adam(disc.parameters(), lr=LR)
optim_gen = torch.optim.Adam(gen.parameters(), lr=LR)

step = 0
NUM_EPOCHS = 350

real: torch.Tensor
for epoch in range(NUM_EPOCHS):
    for batch_idx, real in enumerate(dataloader):
        real = real.to(DEVICE)  # -> BATCH_SIZE x SEQ_SIZE

        # train discriminator
        fake = gen.generate(
            torch.randint(1, VOCAB_SIZE, size=(BATCH_SIZE, 1)).to(DEVICE), SEQ_SIZE
        )  # -> BATCH x SEQ

        # print(fake)

        disc_real = disc(real)  # -> BATCH x 1
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake)  # -> BATCH x 1
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_fake + lossD_real) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        optim_disc.step()

        # train generator
        output = disc(fake)  # -> BATCH x 1
        lossG = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        lossG.backward()
        optim_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

Thanks in advance for any help!