Hello there, I’m working on transformer based GAN and i have trouble training the model. I know that GANs are hard to train and choosing proper hyperparameters can be tricky but I think there might be something wrong with my model. I’m quite new to torch so I don’t see where the problem might be. During the training, discriminator’s loss after few examples goes to 0 while generator’s keeps rising. Even when simplifying disc model to only two embeddings and sequential layers I cannot get generator to produce meaningful data. I think that code itself is quite straightforward but if you have some questions please ask so we can find problems.
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from random import randint
from torch.utils.data import Dataset, DataLoader
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
DEVICE = "cuda"
N_LAYERS = 2
SEQ_SIZE = 48
BATCH_SIZE = 32
N_HEAD = 6
HIDDEN_SIZE = 128
LR = 3e-4
# some data avialable here https://pastebin.com/LLa9baAQ
with open("./jsbdataset.txt", "r") as file:
data = file.read().strip().split(" ")
DATA_LENGHT = len(data)
VOCAB_SIZE = len(set(data))
EMBED_DIM = 360
class JBChorales(Dataset):
def __init__(self):
self.raw_data = [int(num) - 35 if num == 0 else 0 for num in data]
self.data = []
for _ in range(DATA_LENGHT // SEQ_SIZE):
rand_int = randint(0, DATA_LENGHT - SEQ_SIZE)
remainder = rand_int % 4
rand_div_nr = rand_int - remainder
self.data.append(self.raw_data[rand_div_nr : rand_div_nr + SEQ_SIZE])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return torch.tensor(self.data[idx])
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.gen_embed = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
self.pos_embed = nn.Embedding(SEQ_SIZE, EMBED_DIM)
self.enc_layer = nn.TransformerEncoderLayer(EMBED_DIM, N_HEAD, batch_first=True)
self.encoder = nn.TransformerEncoder(self.enc_layer, N_LAYERS)
self.dense = nn.Linear(EMBED_DIM, VOCAB_SIZE)
def forward(self, x: torch.Tensor):
cur_seq_size = x.size()[1]
if cur_seq_size > SEQ_SIZE:
x = x[:, :-SEQ_SIZE, :]
mask = nn.Transformer.generate_square_subsequent_mask(cur_seq_size)
x = self.gen_embed(x) # -> BATCH_SIZE x SEQ_SIZE x EMBED_DIM
pos = self.pos_embed(torch.arange(cur_seq_size, device=DEVICE))
x = x + pos
x = self.encoder.forward(x, mask=mask, is_causal=True)
x = self.dense(x) # -> BATCH_SIZE x SEQ_SIZE x VOCAB_SIZE
x = x[:, -1, :] # -> BATCH_SIZE x 1 x VOCAB_SIZE
x = F.softmax(x, dim=-1)
return x.view(BATCH_SIZE, VOCAB_SIZE) # -> BATCH_SIZE x VOCAB_SIZE
def generate(self, previous: torch.Tensor, n_iter: int) -> torch.Tensor:
for i in range(n_iter - 1):
probs = self.forward(previous)
chosen = torch.argmax(probs, dim=1, keepdim=True)
previous = torch.cat([previous, chosen], dim=1)
return previous
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.disc_embed = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
self.pos_embed = nn.Embedding(SEQ_SIZE, EMBED_DIM)
self.enc_layer = nn.TransformerEncoderLayer(EMBED_DIM, N_HEAD, batch_first=True)
self.encoder = nn.TransformerEncoder(self.enc_layer, N_LAYERS)
self.classifier = nn.Sequential(nn.Linear(EMBED_DIM, 1), nn.Sigmoid())
def forward(self, x):
mask = nn.Transformer.generate_square_subsequent_mask(SEQ_SIZE)
x = self.disc_embed(x) # -> BATCH_SIZE x OUT_SIZE x EMBED_DIM
pos = self.pos_embed(torch.arange(x.size()[1], device=DEVICE))
x = x + pos
x = self.encoder.forward(
x, mask=mask, is_causal=True
) # -> BATCH_SIZE x OUT_SIZE x EMBED_DIM
x = x.mean(dim=1) # -> BATCH_SIZE x EMBED_DIM
x = self.classifier(x) # -> BATCH_SIZE x 1
return x
# TRAINING LOOP
# DATA
dataset = JBChorales()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# MODELS
gen = Generator().to(DEVICE)
disc = Discriminator().to(DEVICE)
# SOME OTHER STUFF
criterion = nn.BCELoss()
optim_disc = torch.optim.Adam(disc.parameters(), lr=LR)
optim_gen = torch.optim.Adam(gen.parameters(), lr=LR)
step = 0
NUM_EPOCHS = 350
real: torch.Tensor
for epoch in range(NUM_EPOCHS):
for batch_idx, real in enumerate(dataloader):
real = real.to(DEVICE) # -> BATCH_SIZE x SEQ_SIZE
# train discriminator
fake = gen.generate(
torch.randint(1, VOCAB_SIZE, size=(BATCH_SIZE, 1)).to(DEVICE), SEQ_SIZE
) # -> BATCH x SEQ
# print(fake)
disc_real = disc(real) # -> BATCH x 1
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake) # -> BATCH x 1
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD = (lossD_fake + lossD_real) / 2
disc.zero_grad()
lossD.backward(retain_graph=True)
optim_disc.step()
# train generator
output = disc(fake) # -> BATCH x 1
lossG = criterion(output, torch.ones_like(output))
gen.zero_grad()
lossG.backward()
optim_gen.step()
if batch_idx == 0:
print(
f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
)
Thanks in advance for any help!