Hello, I am trying to train a WGAN-GP on one-hot encoded data using the architecture developed at https://github.com/av1659/fbgan. My architecture is almost identical; the only differences are the training script and the input data. The problem is that after some batch iterations the loss and the weights go to NaN.
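To pin down where the NaNs first appear, I run roughly the following check after each optimizer step (a minimal sketch, not part of the script below; `model` stands in for my Generator or Discriminator):

```python
import torch

# Make autograd raise an error at the first backward op that produces NaN/Inf gradients.
torch.autograd.set_detect_anomaly(True)

def first_bad_tensor(model):
    """Return the name of the first parameter (or gradient) containing NaN/Inf, else None."""
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            return f"parameter {name}"
        if param.grad is not None and not torch.isfinite(param.grad).all():
            return f"gradient of {name}"
    return None
```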
Here is the training script:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 10 13:37:26 2020
@author: ahtisham
"""
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch import autograd
from sklearn.preprocessing import OneHotEncoder
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, grad
import matplotlib.pyplot as plt
import numpy as np
import os.path
from tqdm import tqdm
from src.gumbel import *
from src.parser import parameter_parser
from src.utils import *
from src.models import Discriminator
from src.models import Generator
#from enhancer_classifer import EnhancerClassifier
#from classifier_parser import parameter_parser2
class WEnhancerGAN:
def __init__(self, args, num_chars=4):
# function to retrieve dataset (augmented enhancers)
self.__init_data__(args)
# assign the parameters from args parser
self.batch_size = args.batch_size
self.hidden = args.hidden_dim
self.lr = args.learning_rate
self.epochs = args.epochs
self.sequence_length = args.max_len
self.discriminator_steps = args.discriminator_steps
self.generator_steps = args.generator_steps
self.directory = args.directory
self.lam = args.lam
self.num_chars = num_chars
self.gpweight = 10
#self.e_classifier = self.load_Enhancer_Classifier()
# call preprocessing class from utils
self.preprocessing = Preprocessing(args)
self.device = "cuda:2"
self.build_GAN_model()
def __init_data__(self, args):
# function used from the utils files (see utils for details)
self.preprocessing = Preprocessing(args)
# read fasta of positive sequences (enhancers)
self.preprocessing.load_data()
self.jan_seq = self.preprocessing.longer_sequences
# self.preprocessing.write_long_seq_file()
if (os.path.exists("oneHotEncodedData.npy")):
self.data = np.load("oneHotEncodedData.npy")
#self.data = self.data[1176130:].astype(float)
print("One hot encoded data present !!! \nShape :" ,self.data.shape)
else:
print("Reading and One hot encoding the sequences")
self.preprocessing.sequencesToOneHotEncoding()
self.data = np.load("oneHotEncodedData.npy")
print("Shape of Read Data:",self.data.shape)
self.data = self.data[1176130:]
#print("blalalalal", self.data[0])
def build_GAN_model(self):
# defining the models
#print(self.num_chars, self.sequence_length, self.batch_size, self.hidden)
self.Generator = Generator(self.num_chars, self.sequence_length, self.batch_size, self.hidden).to(self.device)
self.Discriminator = Discriminator(self.num_chars, self.sequence_length, self.batch_size, self.hidden).to(self.device)
# defining the optimizers
self.d_optim = optim.Adam(self.Discriminator.parameters(), lr=self.lr, betas=(0.5, 0.9))
self.g_optim = optim.Adam(self.Generator.parameters(), lr=self.lr, betas=(0.5, 0.9))
print("Models have been built...")
def Interpolate(self, real_seqs, fake_seqs):
N = real_seqs.shape[0]
theta = torch.tensor(np.random.uniform(size = N), dtype= torch.float).view(N, 1,1,1).to(self.device)
sample = theta * real_seqs + (1-theta) * fake_seqs
return sample
def Gradient_Norm(self, real_data, fake_data):
alpha = torch.rand(self.batch_size, 1, 1)
alpha = alpha.view(-1, 1, 1)
alpha = alpha.expand_as(real_data)
alpha = alpha.to(self.device)
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
interpolates = interpolates.to(self.device)
interpolates = autograd.Variable(interpolates, requires_grad=True)
#interpolates = interpolates + 1e-16
disc_interpolates = self.Discriminator(interpolates)
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones(disc_interpolates.size()).to(self.device),
create_graph=True, retain_graph=True, only_inputs=True)[0]
#gradients = gradients + 1e-16
#gradient_penalty = ((gradients.norm(2, dim=1).norm(2, dim=1) - 1) ** 2)
gradient_penalty = torch.mean((1. - torch.sqrt(1e-8 + torch.sum(gradients.reshape(gradients.size(0), -1) ** 2, dim=1))) ** 2)
return gradient_penalty
'''
def Gradient_Norm(self, model, real_seqs, fake_seqs):
N = real_seqs.shape[0]
_input = self.Interpolate(real_seqs, fake_seqs)
_input = Variable(_input, requires_grad = True)
score = model(_input)
ouputs = torch.ones(score.shape).to(self.device)
gradient = grad( outputs= score,
inputs= _input,
create_graph=True,
retain_graph= True)[0]
grad_norm = torch.sqrt(torch.sum(gradient.view(N, -1) **2, dim=1) + 1e-12)
return (grad_norm - 1) ** 2
'''
def Wasserstein_Loss(self, labels, predictions):
return torch.mean(labels * predictions)
def load_Enhancer_Classifier(self):
args = parameter_parser2()
# load the enhancer classifier class
model = EnhancerClassifier(args).to(self.device)
# load its state dictionary
model.load_state_dict(torch.load("model"))
# return the model
return model
# convert a one-hot matrix back to its nucleotide string (A/G/C/T) and print it
def one_Hot_To_Tokenizer(self, onehot):
print(onehot)
for i in range(300):
temp = onehot[:, i]
n = np.argmax(temp)
if n == 0:
print('A', end='')
elif n == 1:
print('G', end='')
elif n == 2:
print('C', end='')
elif n == 3:
print('T', end='')
def predict_Enhancer_Sequence(self, one_hot_seqs):
predictions = []
for i in range(self.batch_size):
predictions.append(self.e_classifier(self.one_Hot_To_Tokenizer(one_hot_seqs[i])))
def tokenize_string(self,sample):
return tuple(sample.lower().split(' '))
def check_ahtisham(self):
for i in range(len(self.data)):
for j in range(300-1):
temp = self.data[i][j][:]
n = np.argmax(temp)
flag = 0
if n == 0:
flag = 0
elif n == 1:
flag = 1
elif n == 2:
flag = 1
elif n == 3:
flag = 1
print(flag)
def check_data(self):
# check in each row the data is like 1,0,0,0 or 0,1,0,0 or 0,0,1,0
correct = []
falseRows = []
for i in range(len(self.data)):
for j in range(300-1):
control = self.data[i,j,:]
flag = False
for value in control:
if value == 1:
if flag:
falseRows.append((i,j))
correct.append(False)
break
flag = True
if flag:
correct.append(True)
else:
correct.append(False)
falseRows.append((i, j))
return False if False in correct else True
def train_WEnhancerGAN(self):
loader = DataLoader(self.data, batch_size=self.batch_size, drop_last=True)
self.g_loss_a = []
self.d_loss_a = []
self.w_dist_a = []
self.gp_a = []
# define the epochs
epochs = 10
# define the lists for d loss on real and fake data
d_fake_losses, d_real_losses = [], []
# list for gradient penalties
gradient_penalties = []
counter = 0
latent_dimensions = 128
for epoch in tqdm(range(epochs)):
for i,batch in enumerate(loader):
# perform label smoothing
#noise_label = torch.randn(self.batch_size) #* 0.1
# assign the labels
real_labels = (torch.ones(self.batch_size)).to(self.device)
fake_labels = - torch.ones(self.batch_size).to(self.device)
###### ** train the discriminator ** #######
# avg discriminator loss
d_loss_avg = 0
real_seqs = batch.type(torch.FloatTensor).to(self.device)
for _ in range (self.discriminator_steps):
# set the current gradient to zero
self.d_optim.zero_grad()
# generate sequences from the latent space
latent_vectors = torch.randn(self.batch_size, latent_dimensions).to(self.device)
#print("booboloski", latent_vectors.size())
fake_seqs = self.Generator(latent_vectors)
# score the sequences
real_score = self.Discriminator(real_seqs)
fake_score = self.Discriminator(fake_seqs.detach())
# calculate the gradient penalty
gradient_penalty = self.Gradient_Norm(real_seqs, fake_seqs).mean()
# discriminator loss
d_loss = self.Wasserstein_Loss(real_labels, real_score) - self.Wasserstein_Loss(fake_labels, fake_score) + gradient_penalty * self.gpweight
# calc grads
d_loss.backward()
# apply the grads to the weights
self.d_optim.step()
# append the loss
# d_loss_avg += d_loss
self.d_loss_a.append(d_loss.item())
###### *** train the generator *** #####
#set gradients to zero
self.g_optim.zero_grad()
# generate sequences from the latent space
latent_vectors = torch.randn(self.batch_size, latent_dimensions).to(self.device)
fake_seqs = self.Generator(latent_vectors)
fake_s = self.Discriminator(fake_seqs)
g_loss = self.Wasserstein_Loss(fake_labels, fake_s)
g_loss.backward()
self.g_optim.step()
self.g_loss_a.append(g_loss.item())
# append the g loss in the list
#self.g_loss_a.append(g_loss.item())
print("Generator's Loss = ", self.g_loss_a[-1])
print("Generator's Loss = ", self.g_loss_a[-1], "Discriminator's Loss:", self.d_loss_a[-1])
args = parameter_parser()
wgan = WEnhancerGAN(args)
#print("single tensor shape:", wgan.data[0].shape)
wgan.build_GAN_model()
#print("new data shape:", wgan.data.shape)
print(wgan.check_data())
wgan.train_WEnhancerGAN()
#wgan.check_ahtisham()
'''
plt.plot(wgan.g_loss_a)
plt.plot(wgan.d_loss_a)
plt.plot(wgan.gp_a)
plt.show()
latent_vector = torch.randn(size=(128,)).to("cuda:3")
print(wgan.Generator(latent_vector))
'''
Any help would be appreciated. The input data is one-hot encoded with shape (64, 300, 4), i.e. (batch, sequence_length, one_hot_channels).
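For reference, a batch with the same layout can be constructed like this (hypothetical illustration of the data format, not my actual preprocessing):

```python
import numpy as np

# (batch, seq_len, channels) = (64, 300, 4); exactly one of the A/G/C/T channels is 1 per position.
batch = np.zeros((64, 300, 4), dtype=np.float32)
idx = np.random.randint(0, 4, size=(64, 300))            # random nucleotide index per position
batch[np.arange(64)[:, None], np.arange(300)[None, :], idx] = 1.0

assert batch.shape == (64, 300, 4)
assert (batch.sum(axis=-1) == 1).all()                   # one-hot: each position sums to 1
```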