GAN with STGCN doesn't learn

Hi :slightly_smiling_face:
I’m trying to build a GAN with the Spatio-Temporal Graph Convolutional Network (STGCN) to generate time series. I’m working with the METR-LA traffic dataset.
The original STGCN was built for traffic forecasting, so I thought it should be possible to build a GAN on top of it. I need it for my bachelor thesis, and the plan is to extend it with an additional multivariate context vector later. (If you have any ideas for that, I’d be really happy :D)
So back to my actual problem: I trained the GAN for around 1000 epochs with a batch size of 50. I had to reduce the complexity of the original STGCN because I got a CUDA out-of-memory exception during validation.
I tried many different parameters, activation functions, and loss functions, but the loss is still bad. Do you have any ideas to make the model better?
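
To keep validation from blowing up GPU memory, one thing I’m considering is running it in chunks instead of pushing the whole validation set through the Generator at once. A minimal sketch, assuming the Generator(A_wave, input) interface from my code below:

import torch

def validate_in_chunks(generator, A_wave, val_input, chunk_size=200, device='cuda'):
    """Run validation in chunks so the whole set is never on the GPU at once."""
    generator.eval()
    outputs = []
    with torch.no_grad():
        for start in range(0, val_input.shape[0], chunk_size):
            chunk = val_input[start:start + chunk_size].to(device)
            out = generator(A_wave, chunk)
            outputs.append(out.cpu())  # move results off the GPU immediately
    return torch.cat(outputs, dim=0)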

Here’s my code:

import os, time
import argparse
import pickle as pk
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data

from stgcn import STGCN_D, STGCN_G
from utils import generate_dataset, load_metr_la_data, get_normalized_adj, generate_noise, show_train_hist

num_timesteps_input = 12
num_timesteps_output = 3

epochs = 1000
batch_size = 50
n_critic = 1

# results save folder
if not os.path.isdir('RCGAN_MetrLa_Ergebnisse'):
    os.mkdir('RCGAN_MetrLa_Ergebnisse')


parser = argparse.ArgumentParser(description='STGCN')
parser.add_argument('--enable-cuda', action='store_true',
                    help='Enable CUDA')
args = parser.parse_args()
args.device = None
if args.enable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

def train_epoch(training_input, training_target, batch_size):
    """
    Trains one epoch with the given data.
    :param training_input: Training inputs of shape (num_samples, num_nodes,
    num_timesteps_train, num_features).
    :param training_target: Training targets of shape (num_samples, num_nodes,
    num_timesteps_predict).
    :param batch_size: Batch size to use during training.
    :return: Average discriminator and generator loss for this epoch.
    """
    permutation = torch.randperm(training_input.shape[0])

    epoch_training_Dlosses = []
    epoch_training_Glosses = []
    # The loader is only used to step through the data in batch-sized chunks;
    # the actual samples are drawn via the permutation below.
    for i, _ in enumerate(train_loader):
        # Train Discriminator
        Discriminator.train()
        D_optimizer.zero_grad()

        indices = permutation[i * batch_size:(i + 1) * batch_size]
        X_batch, y_batch = training_input[indices], training_target[indices]
        X_batch = X_batch.to(device=args.device)
        y_batch = y_batch.to(device=args.device)
        # Fake/real label tensors with the same shape and device as the targets
        y_0_batch = torch.zeros_like(y_batch)
        y_1_batch = torch.ones_like(y_batch)

        D_out = Discriminator(y_batch)
        D_x_loss = loss_criterion(D_out, y_1_batch)

        z = generate_noise(batch_size, A_wave, num_timesteps_input).to(device=args.device)
        G_out = Generator(A_wave, z)
        z_out = Discriminator(G_out) 
        D_z_loss = loss_criterion(z_out, y_0_batch)
        D_loss = D_x_loss + D_z_loss
        D_loss.backward()
        D_optimizer.step()
        epoch_training_Dlosses.append(D_loss.detach().cpu().numpy())

        if i % n_critic == 0:
            # Train Generator
            z = generate_noise(batch_size, A_wave, num_timesteps_input).to(device=args.device)
            g_out = Generator(A_wave, z)
            z_outputs = Discriminator(g_out)
            G_loss = loss_criterion(z_outputs, y_batch)

            Generator.zero_grad()
            G_loss.backward()
            G_optimizer.step()

            epoch_training_Glosses.append(G_loss.detach().cpu().numpy())
    return sum(epoch_training_Dlosses)/len(epoch_training_Dlosses), sum(epoch_training_Glosses)/len(epoch_training_Glosses) 


if __name__ == '__main__':
    torch.manual_seed(7)

    A, X, means, stds = load_metr_la_data()

    split_line1 = int(X.shape[2] * 0.6)
    split_line2 = int(X.shape[2] * 0.8)

    train_original_data = X[:, :, :split_line1]
    val_original_data = X[:, :, split_line1:split_line2]
    test_original_data = X[:, :, split_line2:]

    training_input, training_target = generate_dataset(train_original_data,
                                                       num_timesteps_input=num_timesteps_input,
                                                       num_timesteps_output=num_timesteps_output)
    val_input, val_target = generate_dataset(val_original_data,
                                             num_timesteps_input=num_timesteps_input,
                                             num_timesteps_output=num_timesteps_output)
    test_input, test_target = generate_dataset(test_original_data,
                                               num_timesteps_input=num_timesteps_input,
                                               num_timesteps_output=num_timesteps_output)
    train_loader = data.DataLoader(training_input, batch_size = batch_size, shuffle = True, drop_last = True)

    A_wave = get_normalized_adj(A)
    A_wave = torch.from_numpy(A_wave)

    A_wave = A_wave.to(device=args.device)

    Discriminator = STGCN_D(A_wave.shape[0], training_input.shape[3], num_timesteps_input, num_timesteps_output).to(device=args.device)
    Generator = STGCN_G(A_wave.shape[0], training_input.shape[3], num_timesteps_input, num_timesteps_output).to(device=args.device)

    D_optimizer = torch.optim.Adam(Discriminator.parameters(), lr=1e-3)
    G_optimizer = torch.optim.Adam(Generator.parameters(), lr=1e-3)

    loss_criterion = nn.MSELoss()

    training_d_losses = []
    training_g_losses = []
    validation_losses = []
    validation_maes = []

    train_hist = {}
    train_hist['Training_D_losses'] = []
    train_hist['Training_G_losses'] = []
    train_hist['Validation_losses'] = []
    train_hist['MAEs'] = []
    train_hist['per_epoch_ptimes'] = []
    train_hist['total_ptime'] = []

    print('Start training')
    start_time = time.time()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        d_loss, g_loss = train_epoch(training_input, training_target,
                           batch_size=batch_size)
        training_d_losses.append(d_loss)
        training_g_losses.append(g_loss)


        # Run validation
        with torch.no_grad():
            Generator.eval()
            val_input = val_input.to(device=args.device)
            val_target = val_target.to(device=args.device)

            out = Generator(A_wave, val_input)

            val_loss = loss_criterion(out, val_target)

            validation_losses.append(val_loss.item())

            out_unnormalized = out.detach().cpu().numpy()*stds[0]+means[0]
            target_unnormalized = val_target.detach().cpu().numpy()*stds[0]+means[0]
            mae = np.mean(np.absolute(out_unnormalized - target_unnormalized))
            validation_maes.append(mae)

            # Drop the reference to the generator output to free GPU memory
            out = None
        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('Epoch: {}'.format(epoch))
        print("Training d_loss: {}".format(training_d_losses[-1]))
        print("Training g_loss: {}".format(training_g_losses[-1]))
        print("Validation loss: {}".format(validation_losses[-1]))
        print("Validation MAE: {}".format(validation_maes[-1]))
        print("Epoch Time: {}".format(per_epoch_ptime))

        train_hist['Training_D_losses'].append(torch.mean(torch.Tensor(training_d_losses)))
        train_hist['Training_G_losses'].append(torch.mean(torch.Tensor(training_g_losses)))
        train_hist['Validation_losses'].append(torch.mean(torch.Tensor(validation_losses)))
        train_hist['MAEs'].append(torch.mean(torch.Tensor(validation_maes)))
        train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

        checkpoint_path = "checkpoints/"
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
        with open("checkpoints/losses.pk", "wb") as fd:
            pk.dump((training_g_losses, validation_losses, validation_maes), fd)
    end_time = time.time()
    total_ptime = end_time-start_time
    train_hist['total_ptime'].append(total_ptime)
    print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), epochs, total_ptime))
    print("Training finish!... save training results")
    torch.save(Generator.state_dict(), "RCGAN_MetrLa_Ergebnisse/generator_param.pkl")
    torch.save(Discriminator.state_dict(), "RCGAN_MetrLa_Ergebnisse/discriminator_param.pkl")
    with open('RCGAN_MetrLa_Ergebnisse/train_hist.pkl', 'wb') as f:
        pk.dump(train_hist, f)

    show_train_hist(train_hist, save=True, path='RCGAN_MetrLa_Ergebnisse/MetrLa_rcGAN_train_hist.png')
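
For completeness, this is roughly how I reload the trained generator afterwards to sample synthetic series from it (a sketch that reuses A_wave, args and the hyperparameters from the script above):

# Rebuild the generator with the same architecture and load the saved weights
Generator = STGCN_G(A_wave.shape[0], training_input.shape[3],
                    num_timesteps_input, num_timesteps_output).to(device=args.device)
Generator.load_state_dict(torch.load("RCGAN_MetrLa_Ergebnisse/generator_param.pkl",
                                     map_location=args.device))
Generator.eval()

# Sample a batch of synthetic sequences from noise
with torch.no_grad():
    z = generate_noise(batch_size, A_wave, num_timesteps_input).to(device=args.device)
    samples = Generator(A_wave, z)  # (batch_size, num_nodes, num_timesteps_output)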

Model:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimeBlock(nn.Module):
    """
    Neural network block that applies a temporal convolution to each node of
    a graph in isolation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        """
        :param in_channels: Number of input features at each node in each time
        step.
        :param out_channels: Desired number of output channels at each node in
        each time step.
        :param kernel_size: Size of the 1D temporal kernel.
        """
        super(TimeBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps,
        num_features=in_channels)
        :return: Output data of shape (batch_size, num_nodes,
        num_timesteps_out, num_features_out=out_channels)
        """
        # Convert into NCHW format for pytorch to perform convolutions.
        X = X.permute(0, 3, 1, 2)
        # out = ReLU(conv1(X) + sigmoid(conv2(X)) + conv3(X))
        out_conv1 = self.conv1(X)
        out_conv2 = torch.sigmoid(self.conv2(X))
        out = F.relu(out_conv1 + out_conv2 + self.conv3(X))
        # Convert back from NCHW to NHWC
        out = out.permute(0, 2, 3, 1)
        return out


class STGCNBlock(nn.Module):
    """
    Neural network block that applies a temporal convolution on each node in
    isolation, followed by a graph convolution, followed by another temporal
    convolution on each node.
    """

    def __init__(self, in_channels, spatial_channels, out_channels,
                 num_nodes):
        """
        :param in_channels: Number of input features at each node in each time
        step.
        :param spatial_channels: Number of output channels of the graph
        convolutional, spatial sub-block.
        :param out_channels: Desired number of output features at each node in
        each time step.
        :param num_nodes: Number of nodes in the graph.
        """
        super(STGCNBlock, self).__init__()
        self.temporal1 = TimeBlock(in_channels=in_channels,
                                   out_channels=out_channels)
        self.Theta1 = nn.Parameter(torch.FloatTensor(out_channels,
                                                     spatial_channels))
        self.temporal2 = TimeBlock(in_channels=spatial_channels,
                                   out_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.Theta1.shape[1])
        self.Theta1.data.uniform_(-stdv, stdv)

    def forward(self, X, A_hat):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps,
        num_features=in_channels).
        :param A_hat: Normalized adjacency matrix.
        :return: Output data of shape (batch_size, num_nodes,
        num_timesteps_out, num_features=out_channels).
        """
        t = self.temporal1(X)
        lfs = torch.einsum("ij,jklm->kilm", [A_hat, t.permute(1, 0, 2, 3)])
        t2 = F.relu(torch.matmul(lfs, self.Theta1))
        t3 = self.temporal2(t2)
        return self.batch_norm(t3)
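
# Note on the einsum above: "ij,jklm->kilm" applies the normalized adjacency
# A_hat (num_nodes x num_nodes) along the node axis, i.e. for each batch
# element, time step and channel it computes A_hat @ t, mixing information
# between neighboring nodes before the Theta1 feature transform.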

class STGCN_D(nn.Module):

    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output):
        super(STGCN_D, self).__init__()
        kernel_size = 3
        self.conv1 = nn.Conv1d(1, 50, (1, kernel_size)) #207,3 (1, kernel_size)
        self.leakyrelu1 = nn.LeakyReLU(0.2)
        self.batchnorm1 = nn.BatchNorm1d(3450)
        self.conv2 = nn.Conv1d(3450, 207, (1, kernel_size))
        self.leakyrelu2 = nn.LeakyReLU(0.2)
        self.batchnorm2 = nn.BatchNorm1d(207)
        self.fc1 = nn.Linear(1, 414)
        self.leakyrelu3 = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(414, 3)

    def forward(self, X):
        """
        :param X: Input data of shape (batch_size, num_nodes,
        num_timesteps_output).
        """
        batch_size = X.size(0)
        X = X.unsqueeze(1)  # -> (batch_size, 1, num_nodes, num_timesteps)
        X = self.conv1(X)
        X = self.leakyrelu1(X)
        X = X.reshape(batch_size, 3450, 3)
        X = self.batchnorm1(X)
        X = X.unsqueeze(2)
        X = self.conv2(X)
        X = self.leakyrelu2(X)
        X = X.reshape(batch_size, 207, 1)
        X = self.batchnorm2(X)
        #X = X.view(X.size(0), X.size(2), X.size(1))
        X = self.fc1(X)
        X = self.leakyrelu3(X)
        X = self.fc2(X)
        X = torch.sigmoid(X)
        return X

class STGCN_G(nn.Module):
    """
    Spatio-temporal graph convolutional network as described in
    https://arxiv.org/abs/1709.04875v3 by Yu et al.
    Input should have shape (batch_size, num_nodes, num_input_time_steps,
    num_features).
    """

    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output):
        """
        :param num_nodes: Number of nodes in the graph.
        :param num_features: Number of features at each node in each time step.
        :param num_timesteps_input: Number of past time steps fed into the
        network.
        :param num_timesteps_output: Desired number of future time steps
        output by the network.
        """
        # All channel widths reduced to an eighth of the original STGCN
        super(STGCN_G, self).__init__()
        self.block1 = STGCNBlock(in_channels=num_features, out_channels=8,
                                 spatial_channels=2, num_nodes=num_nodes)
        self.block2 = STGCNBlock(in_channels=8, out_channels=8,
                                 spatial_channels=2, num_nodes=num_nodes)
        self.last_temporal = TimeBlock(in_channels=8, out_channels=8)
        self.fully = nn.Linear((num_timesteps_input - 2 * 5) * 8,
                               num_timesteps_output)

    def forward(self, A_hat, X):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps,
        num_features=in_channels).
        :param A_hat: Normalized adjacency matrix.
        """
        out1 = self.block1(X, A_hat)
        out2 = self.block2(out1, A_hat)
        out3 = self.last_temporal(out2)
        out4 = self.fully(out3.reshape((out3.shape[0], out3.shape[1], -1)))
        out4 = torch.tanh(out4)  # my idea: squash outputs to (-1, 1)
        return out4
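
To make the tensor shapes easier to follow, here is a quick shape check (just a sketch with random inputs, assuming the default settings from above):

import torch
from stgcn import STGCN_G, STGCN_D
from utils import generate_noise

num_nodes, num_features = 207, 2  # METR-LA: 207 sensors, 2 features per node
A_wave = torch.eye(num_nodes)     # stand-in for the normalized adjacency matrix

G = STGCN_G(num_nodes, num_features, num_timesteps_input=12, num_timesteps_output=3)
D = STGCN_D(num_nodes, num_features, num_timesteps_input=12, num_timesteps_output=3)

z = generate_noise(50, A_wave, 12)  # (50, 207, 12, 2)
fake = G(A_wave, z)                 # (50, 207, 3): 3 future steps per node
scores = D(fake)                    # (50, 207, 3): per-node scores in (0, 1)
print(z.shape, fake.shape, scores.shape)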

Some utils:

import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt
import torch

def load_metr_la_data():
    if (not os.path.isfile('data/adj_mat.npy')
        or not os.path.isfile('data/node_values.npy')):
        with zipfile.ZipFile('data/METR-LA.zip', 'r') as zip_ref:
            zip_ref.extractall('data/')
    A = np.load('data/adj_mat.npy')
    X = np.load('data/node_values.npy').transpose((1,2,0))
    X = X.astype(np.float32)

    # Normalization using Z-score method
    means = np.mean(X, axis=(0,2))
    X = X - means.reshape(1, -1, 1)
    stds = np.std(X, axis=(0,2))
    X = X / stds.reshape(1,-1,1)

    return A, X, means, stds

def get_normalized_adj(A):
    """
    Returns the degree normalized adjacency matrix.
    """
    # A_wave = D^(-1/2) (A + I) D^(-1/2)
    A = A + np.diag(np.ones(A.shape[0], dtype=np.float32))
    D = np.array(np.sum(A, axis=1)).reshape((-1,))
    D[D <= 10e-5] = 10e-5  # Prevent infs
    diag = np.reciprocal(np.sqrt(D))
    A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A),
                         diag.reshape((1, -1)))
    return A_wave

def generate_dataset(X, num_timesteps_input, num_timesteps_output):
    """
    Takes node features for the graph and divides them into multiple samples
    along the time-axis by sliding a window of size (num_timesteps_input+
    num_timesteps_output) across it in steps of 1.
    :param X: Node features of shape (num_vertices, num_features,
    num_timesteps)
    :return:
        - Node features divided into multiple samples. Shape is
          (num_samples, num_vertices, num_features, num_timesteps_input).
        - Node targets for the samples (first feature only). Shape is
          (num_samples, num_vertices, num_timesteps_output).
    """
    # Generate the beginning index and the ending index of a sample, which
    # contains (num_points_for_training + num_points_for_predicting) points
    indices = [(i, i + (num_timesteps_input + num_timesteps_output)) for i
               in range(X.shape[2] - (
                num_timesteps_input + num_timesteps_output) + 1)]

    # Save samples
    features, target = [], []
    for i, j in indices:
        features.append(
            X[:, :, i: i + num_timesteps_input].transpose(
                (0, 2, 1)))
        target.append(X[:, 0, i + num_timesteps_input: j])

    return torch.from_numpy(np.array(features)), \
           torch.from_numpy(np.array(target))

def generate_noise(batch_size, A_wave, input_stepsize):
    # Noise with the same layout as the real inputs:
    # (batch_size, num_nodes, num_timesteps_input, num_features=2)
    noise = torch.randn(batch_size, A_wave.shape[0], input_stepsize, 2)
    return noise

def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    x = range(len(hist['Training_D_losses']))

    y1 = hist['Training_D_losses']
    y2 = hist['Training_G_losses']
    y3 = hist['Validation_losses']
    y4 = hist['MAEs']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')
    plt.plot(x, y3, label='Val_loss')
    plt.plot(x, y4, label='MAE')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()
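
And for reference, a quick sanity check of the sliding-window shapes from generate_dataset (a sketch with a tiny random array):

import numpy as np
from utils import generate_dataset

# Tiny stand-in for METR-LA: 4 nodes, 2 features, 100 time steps
X = np.random.randn(4, 2, 100).astype(np.float32)
features, target = generate_dataset(X, num_timesteps_input=12, num_timesteps_output=3)
print(features.shape)  # torch.Size([86, 4, 12, 2]); 100 - (12 + 3) + 1 = 86 windows
print(target.shape)    # torch.Size([86, 4, 3]); first feature only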

After 1000 epochs, the validation MAE is around 22, the training d_loss around 0.43, the training g_loss around 1.19, and the validation loss of the generator around 1.6.