NANs while training WGAN-GP on one hot encoded data

Hello, I am trying to train a WGAN-GP on one-hot encoded data using the architecture from https://github.com/av1659/fbgan. My architecture is almost identical; the only differences are the training script and the input data. The problem is that after some number of batch iterations the loss and the weights become NaN.
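
To narrow down where the values first blow up, I am planning to add a small NaN probe like the sketch below (this is not part of the original script; has_nan and loss_is_finite are just helpers I would call on the discriminator/generator and on d_loss after each optimizer step):

import torch

def has_nan(module: torch.nn.Module) -> bool:
    # True if any parameter of the module contains a NaN value
    return any(torch.isnan(p).any().item() for p in module.parameters())

def loss_is_finite(loss: torch.Tensor) -> bool:
    # True if the loss is neither NaN nor +/- inf
    return bool(torch.isfinite(loss).all().item())

Calling has_nan(self.Discriminator) and loss_is_finite(d_loss) right after d_optim.step() should at least tell me the first batch where the values degenerate.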

Here is the training script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 10 13:37:26 2020

@author: ahtisham
"""
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch import autograd
from sklearn.preprocessing import OneHotEncoder
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, grad

import matplotlib.pyplot as plt
import numpy as np
import os.path
from src.gumbel import *
from src.parser import parameter_parser
from src.utils import *
from src.models import Discriminator
from src.models import Generator
#from enhancer_classifer import EnhancerClassifier
#from classifier_parser import parameter_parser2
from tqdm import tqdm

class WEnhancerGAN:
    def __init__(self, args, num_chars=4):
        # function to retrieve dataset (augmented enhancers)
        self.__init_data__(args)

        # assign the parameters from args parser
        self.batch_size = args.batch_size
        self.hidden = args.hidden_dim
        self.lr = args.learning_rate
        self.epochs = args.epochs
        self.sequence_length = args.max_len
        self.discriminator_steps = args.discriminator_steps
        self.generator_steps = args.generator_steps
        self.directory = args.directory
        self.lam = args.lam
        self.num_chars = num_chars
        self.gpweight = 10
        #self.e_classifier = self.load_Enhancer_Classifier()

        # call preprocessing class from utils
        self.preprocessing = Preprocessing(args)

        self.device = "cuda:2"
        self.build_GAN_model()

    def __init_data__(self, args):

        # function used from the utils files (see utils for details)
        self.preprocessing = Preprocessing(args)

        # read fasta of positive sequences (enhancers)
        self.preprocessing.load_data()
        self.jan_seq = self.preprocessing.longer_sequences
        # self.preprocessing.write_long_seq_file()

        if (os.path.exists("oneHotEncodedData.npy")):
            self.data = np.load("oneHotEncodedData.npy")
            #self.data = self.data[1176130:].astype(float)
            print("One hot encoded data present !!! \nShape :", self.data.shape)
        else:
            print("Reading and One hot encoding the sequences")
            self.preprocessing.sequencesToOneHotEncoding()
            self.data = np.load("oneHotEncodedData.npy")
            print("Shape of Read Data:", self.data.shape)
            self.data = self.data[1176130:]

            #print("blalalalal", self.data[0])

    def build_GAN_model(self):

        # defining the models
        #print(self.num_chars, self.sequence_length, self.batch_size, self.hidden)
        self.Generator = Generator(self.num_chars, self.sequence_length, self.batch_size, self.hidden).to(self.device)
        self.Discriminator = Discriminator(self.num_chars, self.sequence_length, self.batch_size, self.hidden).to(self.device)

        # defining the optimizers
        self.d_optim = optim.Adam(self.Discriminator.parameters(), lr=self.lr, betas=(0.5, 0.9))
        self.g_optim = optim.Adam(self.Generator.parameters(), lr=self.lr, betas=(0.5, 0.9))

        print("Models have been built...")

    def Interpolate(self, real_seqs, fake_seqs):
        N = real_seqs.shape[0]
        theta = torch.tensor(np.random.uniform(size=N), dtype=torch.float).view(N, 1, 1, 1).to(self.device)
        sample = theta * real_seqs + (1 - theta) * fake_seqs
        return sample

    def Gradient_Norm(self, real_data, fake_data):
        alpha = torch.rand(self.batch_size, 1, 1)
        alpha = alpha.view(-1, 1, 1)
        alpha = alpha.expand_as(real_data)
        alpha = alpha.to(self.device)
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        interpolates = interpolates.to(self.device)
        interpolates = autograd.Variable(interpolates, requires_grad=True)
        #interpolates = interpolates + 1e-16
        disc_interpolates = self.Discriminator(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).to(self.device),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        #gradients = gradients + 1e-16
        #gradient_penalty = ((gradients.norm(2, dim=1).norm(2, dim=1) - 1) ** 2)
        gradient_penalty = torch.mean((1. - torch.sqrt(1e-8 + torch.sum(gradients.reshape(gradients.size(0), -1) ** 2, dim=1))) ** 2)
        return gradient_penalty
    '''
    def Gradient_Norm(self, model, real_seqs, fake_seqs):
        N = real_seqs.shape[0]

        _input = self.Interpolate(real_seqs, fake_seqs)

        _input = Variable(_input, requires_grad=True)

        score = model(_input)
        outputs = torch.ones(score.shape).to(self.device)
        gradient = grad(outputs=score,
                        inputs=_input,
                        create_graph=True,
                        retain_graph=True)[0]
        grad_norm = torch.sqrt(torch.sum(gradient.view(N, -1) ** 2, dim=1) + 1e-12)
        return (grad_norm - 1) ** 2
    '''
    def Wasserstein_Loss(self, labels, predictions):
        return torch.mean(labels * predictions)

    def load_Enhancer_Classifier(self):

        args = parameter_parser2()

        # load the enhancer classifier class
        model = EnhancerClassifier(args).to(self.device)

        # load its state dictionary
        model.load_state_dict(torch.load("model"))

        # return the model
        return model

    # print the nucleotide sequence encoded by a one-hot matrix
    def one_Hot_To_Tokenizer(self, onehot):
        print(onehot)
        for i in range(300):
            temp = onehot[:, i]
            n = np.argmax(temp)
            if n == 0:
                print('A', end='')
            elif n == 1:
                print('G', end='')
            elif n == 2:
                print('C', end='')
            elif n == 3:
                print('T', end='')

    def predict_Enhancer_Sequence(self, one_hot_seqs):
        predictions = []

        for i in range(self.batch_size):
            predictions.append(self.e_classifier(self.one_Hot_To_Tokenizer(one_hot_seqs[i])))

    def tokenize_string(self, sample):
        return tuple(sample.lower().split(' '))

    def check_ahtisham(self):
        for i in range(len(self.data)):
            for j in range(300 - 1):
                temp = self.data[i][j][:]
                n = np.argmax(temp)
                flag = 0
                if n == 0:
                    flag = 0
                elif n == 1:
                    flag = 1
                elif n == 2:
                    flag = 1
                elif n == 3:
                    flag = 1
                print(flag)


    def check_data(self):
        # check that each position in the data is one-hot, i.e. 1,0,0,0 or 0,1,0,0 or 0,0,1,0 or 0,0,0,1
        correct = []
        falseRows = []
        for i in range(len(self.data)):
            for j in range(300 - 1):
                control = self.data[i, j, :]
                flag = False
                for value in control:
                    if value == 1:
                        if flag:
                            falseRows.append((i, j))
                            correct.append(False)
                            break
                        flag = True
                if flag:
                    correct.append(True)
                else:
                    correct.append(False)
                    falseRows.append((i, j))
        return False if False in correct else True


    def train_WEnhancerGAN(self):

        loader = DataLoader(self.data, batch_size=self.batch_size, drop_last=True)

        self.g_loss_a = []
        self.d_loss_a = []
        self.w_dist_a = []
        self.gp_a = []

        # define the number of epochs
        epochs = 10

        # define the lists for d loss on real and fake data
        d_fake_losses, d_real_losses = [], []

        # list for gradient penalties
        gradient_penalties = []
        counter = 0
        latent_dimensions = 128

        for epoch in tqdm(range(epochs)):
            for i, batch in enumerate(loader):

                # perform label smoothing
                #noise_label = torch.randn(self.batch_size) #* 0.1

                # assign the labels
                real_labels = (torch.ones(self.batch_size)).to(self.device)
                fake_labels = -torch.ones(self.batch_size).to(self.device)


                ###### ** train the discriminator ** #######

                # avg discriminator loss
                d_loss_avg = 0
                real_seqs = batch.type(torch.FloatTensor).to(self.device)
                for _ in range(self.discriminator_steps):

                    # set the current gradient to zero
                    self.d_optim.zero_grad()

                    # generate sequences from the latent space
                    latent_vectors = torch.randn(self.batch_size, latent_dimensions).to(self.device)
                    #print("booboloski", latent_vectors.size())

                    fake_seqs = self.Generator(latent_vectors)


                    # score the sequences
                    real_score = self.Discriminator(real_seqs)
                    fake_score = self.Discriminator(fake_seqs.detach())

                    # calculate the gradient penalty
                    gradient_penalty = self.Gradient_Norm(real_seqs, fake_seqs).mean()

                    # discriminator loss
                    d_loss = self.Wasserstein_Loss(real_labels, real_score) - self.Wasserstein_Loss(fake_labels, fake_score) + gradient_penalty * self.gpweight

                    # calc grads
                    d_loss.backward()

                    # apply the grads to the weights
                    self.d_optim.step()

                    # append the loss
                    # d_loss_avg += d_loss
                    self.d_loss_a.append(d_loss.item())


                ######  *** train the generator *** #####
                # set gradients to zero
                self.g_optim.zero_grad()
                # generate sequences from the latent space
                latent_vectors = torch.randn(self.batch_size, latent_dimensions).to(self.device)
                fake_seqs = self.Generator(latent_vectors)

                fake_s = self.Discriminator(fake_seqs)
                g_loss = self.Wasserstein_Loss(fake_labels, fake_s)

                g_loss.backward()
                self.g_optim.step()

                # append the g loss in the list
                self.g_loss_a.append(g_loss.item())

            print("Generator's Loss = ", self.g_loss_a[-1], "Discriminator's Loss:", self.d_loss_a[-1])

args = parameter_parser()
wgan = WEnhancerGAN(args)
#print("single tensor shape:", wgan.data[0].shape)
wgan.build_GAN_model()
#print("new data shape:", wgan.data.shape)
print(wgan.check_data())
wgan.train_WEnhancerGAN()
#wgan.check_ahtisham()

'''

plt.plot(wgan.g_loss_a)
plt.plot(wgan.d_loss_a)
plt.plot(wgan.gp_a)
plt.show()

latent_vector = torch.randn(size=(128,)).to("cuda:3")
print(wgan.Generator(latent_vector))
'''

Any help would be appreciated. The input data is one-hot encoded with shape (64, 300, 4), i.e. (batch, sequence length, one-hot channels).
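
For reference, this is roughly how I sanity-check the encoding before training (just a sketch; it assumes the oneHotEncodedData.npy file produced by the script above):

import numpy as np

data = np.load("oneHotEncodedData.npy")
print(data.shape)                           # expected (N, 300, 4)
print(np.unique(data))                      # expected [0. 1.]
print(np.allclose(data.sum(axis=-1), 1.0))  # every position sums to 1, i.e. exactly one-hot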