Training time too high when tensorflow code converted to pytorch

I converted a Tensorflow code to pytorch. However, when I printed the number of trainable model parameters, the pytorch version is showing just half the number of parameters as the original tensorflow code. Also, the training time has increased three times for the same hyperparameters. @ptrblck or anyone expert here, can you kindly take a look if I did something redundant, especially the part where I am initializing the model and the part where I am calculating the gradients. The backpropagation part is done like this.

Backpropagation

optimizer.zero_grad()
tot_loss.backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
optimizer.step()

Thank you so much!

The original tensorflow code can be found here:

Pytorch code:

import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
from TS_datasets import getBlood
import time
import numpy as np
import matplotlib.pyplot as plt
from utils import classify_with_knn, interp_data, mse_and_corr, dim_reduction_plot
import math
from scipy import stats
import scipy
import os
import datetime
from scipy.stats import gaussian_kde
from math import sqrt
from math import log
import tensorflow_probability as tfp
from torch import optim
from torch.autograd import Variable
from math import sqrt
from math import log
from scipy.stats import gaussian_kde

from tensorflow import keras as K

dim_red = 1 # perform PCA on the codes and plot the first two components
plot_on = 1 # plot the results, otherwise only textual output is returned
interp_on = 0 # interpolate data (needed if the input time series have different length)
tied_weights = 0 # train an AE where the decoder weights are the econder weights transposed
lin_dec = 1 # train an AE with linear activations in the decoder

parse input data

parser = argparse.ArgumentParser()
parser.add_argument("–code_size", default=20, help=“size of the code”, type=int)
parser.add_argument("–w_reg", default=0.001, help=“weight of the regularization in the loss function”, type=float)
parser.add_argument("–a_reg", default=0.2, help=“weight of the kernel alignment”, type=float)
parser.add_argument("–num_epochs", default=5000, help=“number of epochs in training”, type=int)
parser.add_argument("–batch_size", default=25, help=“number of samples in each batch”, type=int)
parser.add_argument("–max_gradient_norm", default=1.0, help=“max gradient norm for gradient clipping”, type=float)
parser.add_argument("–learning_rate", default=0.001, help=“Adam initial learning rate”, type=float)
parser.add_argument("–hidden_size", default=30, help=“size of the code”, type=int)
args = parser.parse_args()
print(args)

================= DATASET =================

(train_data, train_labels, train_len, _, K_tr,
valid_data, _, valid_len, _, K_vs,
test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel=‘TCK’,
inp=‘zero’) # data shape is [T, N, V] = [time_steps, num_elements, num_var]

sort test data (for a better visualization of the inner product of the codes)

test_data = test_data_orig
train_data = train_data
valid_data = valid_data
test_data = test_data

print(
‘\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n’.format(train_data.shape, valid_data.shape, test_data.shape))

input_length = train_data.shape[1] # same for all inputs

================= GRAPH =================

device = “cuda” if torch.cuda.is_available() else “cpu”
print(f"Using {device} device")

encoder_inputs = train_data
prior_k = K_tr

# ----- ENCODER -----

input_length = encoder_inputs.shape[1]

class Model(nn.Module):
def init(self):
super(Model, self).init()

    self.We1 = torch.nn.Parameter(torch.Tensor(input_length, args.hidden_size).uniform_(-1.0 / math.sqrt(input_length), 1.0 / math.sqrt(input_length)))
    self.We2 = torch.nn.Parameter(torch.Tensor(args.hidden_size, args.code_size).uniform_(-1.0 / math.sqrt(args.hidden_size), 1.0 / math.sqrt(args.hidden_size)))

    self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
    self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))

def encode_only(self, train_data):
    hidden_1 = torch.tanh(torch.matmul(train_data.float(), self.We1) + self.be1)
    code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
    return code

def kernel_compute(self, code):
    code_K = torch.mm(code, torch.t(code))
    return  code_K



def encoder_decoder(self,encoder_inputs):
    hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
    code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)

    # ----- DECODER -----
    if tied_weights:

        Wd1 = torch.transpose(We2)
        Wd2 = torch.transpose(We1)

    else:

        Wd1 = torch.nn.Parameter(
            torch.Tensor(args.code_size, args.hidden_size).uniform_(-1.0 / math.sqrt(args.code_size),
                                                                       1.0 / math.sqrt(args.code_size)))
        Wd2 = torch.nn.Parameter(
            torch.Tensor(args.hidden_size, input_length).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                                         1.0 / math.sqrt(args.hidden_size)))

        bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        bd2 = torch.nn.Parameter(torch.zeros([input_length]))

        if lin_dec:
            hidden_2 = torch.matmul(code, Wd1) + bd1
        else:
            hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)

        dec_out = torch.matmul(hidden_2, Wd2) + bd2

        return  dec_out

def loss(self,code, prior_K):

    #print("CODE SHAPE:", code.size())
    # kernel on codes
    code_K = torch.mm(code,torch.t(code)) # dot product of tensors
    #print ("CODE_K SHAPE:", code_K.size() )

    # ----- LOSS -----
    # kernel alignment loss with normalized Frobenius norm
    code_K_norm = code_K / torch.linalg.matrix_norm(code_K, ord='fro', dim=(- 2, - 1))
    prior_K_norm = prior_K / torch.linalg.matrix_norm(prior_K, ord='fro', dim=(- 2, - 1))
    k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm,prior_K_norm), ord='fro', dim=(- 2, - 1))


    return k_loss

Initialize model

model = Model()

L2 loss

reg_loss = 0

parameters = torch.nn.utils.parameters_to_vector(model.parameters())
#print (“PARAMS:”, (parameters))

optimizer = torch.optim.Adam(model.parameters(),args.learning_rate)

trainable parameters count

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(‘Total parameters: {}’.format(total_params))

============= TENSORBOARD =============

writer = SummaryWriter()

================= TRAINING =================

initialize training variables

time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = “logs/dkae_models/m_0.ckpt”

logdir = os.path.join(“logs”, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

try:
for ep in range(args.num_epochs):

    # shuffle training data
    idx = np.random.permutation(train_data.shape[0])
    train_data_s = train_data[idx, :]
    K_tr_s = K_tr[idx, :][:, idx]


    for batch in range(max_batches):
        fdtr = {}
        fdtr["encoder_inputs"] = train_data_s[(batch) * batch_size:(batch + 1) * batch_size, :]
        fdtr["prior_K"] =  K_tr_s[(batch) * batch_size:(batch + 1) * batch_size,
                         (batch) * batch_size:(batch + 1) * batch_size]

        encoder_inputs = (fdtr["encoder_inputs"].astype(float))
        encoder_inputs = torch.from_numpy(encoder_inputs)
        #print("TYPE ENCODER_INP IN TRAIN:", type(encoder_inputs))

        prior_K = (fdtr["prior_K"].astype(float))
        prior_K = torch.from_numpy(prior_K)

        dec_out = model.encoder_decoder(encoder_inputs)
        #print("DEC OUT TRAIN:", dec_out)

        reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
        reconstruct_loss = reconstruct_loss.float()
        #print("RECONS LOSS TRAIN:", reconstruct_loss)

        k_loss = model.loss(dec_out,prior_K)
        k_loss = k_loss.float()
        #print ("K_LOSS:", k_loss)

        tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
        tot_loss = tot_loss.float()
        #print("TOT_LOSS:", (tot_loss))

    

        # Backpropagation
        optimizer.zero_grad()
        tot_loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
        optimizer.step()


        # clipped_gradients, _ = torch.nn.utils.clip_grad_norm(gradients, args.max_gradient_norm)
        # update_step = optimizer.apply_gradients(zip(clipped_gradients, parameters))

        train_loss = reconstruct_loss
        train_kloss = k_loss

        loss_track.append(train_loss)
        kloss_track.append(train_kloss)

    #check training progress on the validations set (in blood data valid=train)
    if ep % 100 == 0:
        print('Ep: {}'.format(ep))

        fdvs = {}
        fdvs["encoder_inputs"] = valid_data
        fdvs["prior_K"] = K_vs


        #outvs, lossvs, klossvs, vs_code_K, summary = sess.run(
         #   [dec_out, reconstruct_loss, k_loss, code_K, merged_summary], fdvs)

        encoder_inp = (fdvs["encoder_inputs"].astype(float))
        encoder_inp = torch.from_numpy(encoder_inp)

        prior_K = (fdvs["prior_K"].astype(float))
        prior_K = torch.from_numpy(prior_K)


        dec_out_val = model.encoder_decoder(encoder_inp)
        #print ("DEC OUT VAL:", dec_out_val)

        reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)

        k_loss = model.loss(dec_out_val,prior_K)
        tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
        tot_loss = tot_loss.float()

        writer.add_scalar("reconstruct_loss", reconstruct_loss, ep)
        writer.add_scalar("k_loss", k_loss, ep)
        #writer.add_scalar("tot_loss", tot_loss, ep)

        outvs = dec_out_val
        lossvs = reconstruct_loss
        klossvs = k_loss

        #code_K = torch.tensordot(dec_out_val, torch.transpose(dec_out_val), axes=1)
        code_K = torch.mm(dec_out_val, torch.t(dec_out_val))
        vs_code_K = code_K

        # print("LOSS TRACK:", (loss_track))
        # print("KLOSS TRACK:", (loss_track))

        print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
        lossvs, klossvs, torch.mean(torch.stack(loss_track[-100:])), torch.mean(torch.stack(kloss_track[-100:]))))

        # Save model yielding best results on validation
        if lossvs < min_vs_loss:
            min_vs_loss = lossvs
            torch.save(model, model_dir)
            torch.save(model.state_dict(), 'logs/dkae_models/best-model-parameters.pt')

            #save_path = saver.save(sess, model_name)

except KeyboardInterrupt:
print(‘training interrupted’)

time_tr_end = time.time()
print(‘Tot training time: {}’.format((time_tr_end - time_tr_start) // 60))

================= TEST =================

print(’************ TEST ************ \n>>restoring from:’ + model_dir + ‘<<’)

model_test = torch.load(model_dir)
#tr_code = sess.run(code, {encoder_inputs: train_data})

encoder_inputs_tr = train_data
encoder_inputs_tr = encoder_inputs_tr.astype(float)
encoder_inputs_tr = torch.from_numpy(encoder_inputs_tr)

tr_code = model.encode_only(encoder_inputs_tr)

encoder_inputs_te = test_data
encoder_inputs_te = encoder_inputs_te.astype(float)
encoder_inputs_te = torch.from_numpy(encoder_inputs_te)

ts_code = model.encode_only(encoder_inputs_te)
ts_code_K = model.kernel_compute(ts_code)

dec_out = model.encoder_decoder(encoder_inputs_te)

#reconstruct_loss = torch.nn.MSELoss()(encoder_inputs_te, dec_out)
reconstruct_loss = torch.mean((dec_out - encoder_inputs_te) ** 2)

pred = dec_out
pred_loss = reconstruct_loss

pred = pred.detach().numpy()

print(“Pred shape:”, pred.shape)
print(‘Test loss: %.3f’ % (np.mean((pred - test_data) ** 2)))
print(“TS code shape:”, ts_code.shape)
print(“TS code K shape:”, ts_code_K.shape)

reverse transformations

print(“Test data shape:”, test_data_orig.shape)

pred = np.reshape(pred, (test_data_orig.shape[1], test_data_orig.shape[0], test_data_orig.shape[2]))

pred = np.reshape(pred, (test_data_orig.shape[1], test_data_orig.shape[0]))
#print(“Pred shape after reshaping:”, pred.shape)

pred = np.transpose(pred,axes=[1,0,2])

pred = np.transpose(pred, axes=[1, 0])
print(“Pred shape after transposing:”, pred.shape)
test_data = test_data_orig
print(“test data shape:”, test_data.shape)

MSE and corr

test_mse, test_corr = mse_and_corr(test_data, pred, test_len)
print(‘Test MSE: %.3f\nTest Pearson correlation: %.3f’ % (test_mse, test_corr))

kNN classification on the codes

acc, f1, auc = classify_with_knn(tr_code, train_labels[:, 0], ts_code, test_labels[:, 0], k=1)

print(“train labels shape:”, train_labels.shape)

tr_code = tr_code.detach().numpy()
ts_code = ts_code.detach().numpy()

acc, f1, auc = classify_with_knn(tr_code, train_labels[:], ts_code, test_labels[:], k=1)
print(‘kNN – acc: %.3f, F1: %.3f, AUC: %.3f’ % (acc, f1, auc))

dim reduction plots

if dim_red:
dim_reduction_plot(ts_code, test_labels, 1)

train_writer.close()

writer.close()

This line of code:

tot_loss.backward(retain_graph=True)

is usually unnecessary and often used to mask another error.
Could you describe why you are retaining the graph as it would keep the entire computation graph alive, would increase the memory usage, and slow down your code?

hi @ptrblck , thank you for checking this. I also thought about the same, but if I do not retain the computation graph, then it throws me an error like this. Can you tell me if there is any other way I can do this. In order to calculate the gradient descent, I believe I will need to save the gradients from the last epoch some way. Please correct me if I am wrong. Thank you again!

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

No, in a standard SGD approach your training loop would use:

  • a forward pass to calculate the output and store the intermediate activations
  • loss calculation using the current output and the corresponding target
  • backward pass computing the gradients (using the loss, parameters, and stored intermediate activations from the forward pass)
  • parameter update using the computed gradients

During the gradient calculation in the backward pass, the intermediate forward activations are used and freed afterwards. Calling backward a second time would thus result in an error you are seeing.

Your current code is hard to read as you didn’t format it properly (you can wrap it into three backticks ```), but I would start with removing the code snippets which keep references to any outputs or losses such as:

        train_loss = reconstruct_loss
        train_kloss = k_loss

        loss_track.append(train_loss)
        kloss_track.append(train_kloss)

This will not only increase the memory usage in each iteration but will try to keep the computation graph alive. I don’t know if this is causing the error as your code is also not executable. If not, try to post a minimal, executable code snippet showing the error.

Hi @ptrblck , thanks again for your kind feedback. I have removed those lines with multiple variable assignments. But it did not reduce training time. Also, the loss is also stuck which means I have done something wrong while converting the TF code. I have copied the original TF code in the post. It’s also in github :

The differences I am observing are:

  1. Although I have saved the initializations of the weights and the biases of the encoder and decoder as torch parameters, the total trainable parameters are just half than the TF code.

  2. The part where I am calling the optimizer and doing the backprop. It only works when I use “retain_graph = True” flag. This I guess is taking the huge time in the training.

  3. Again, I have two loss components, reconstruction loss and a kernel alignment loss. The TF code nicely optimizes both losses. Here, I see both of them stuck and then starts increasing. I have calculated the reconstruction loss without any libraries and the kernel loss is a Frobenious distance. I am calling the backward() on the total loss after adding them up. There is a regularization constant “reg_loss” which I am multiplying with while calculating the total loss. This parameter depends on the number of trainable parameters.
    TF code:

reg_loss = 0
for tf_var in tf.trainable_variables():
    reg_loss += tf.reduce_mean(tf.nn.l2_loss(tf_var))

tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss

Pytorch code:

reg_loss = 0

parameters = torch.nn.utils.parameters_to_vector(model.parameters())
#print ("PARAMS:", (parameters))
for tf_var in parameters:
    reg_loss += torch.mean(torch.linalg.norm(tf_var))

tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss

I believe this is where something I did seriously wrong, for which the loss is not optimizing. Request you kindly to have a look.

import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
from TS_datasets import getBlood
import time
import numpy as np
import matplotlib.pyplot as plt
from utils import classify_with_knn, interp_data, mse_and_corr, dim_reduction_plot
import math
from scipy import stats
import scipy
import os
import datetime
from scipy.stats import gaussian_kde
from math import sqrt
from math import log
import tensorflow_probability as tfp
from torch import optim
from torch.autograd import Variable
from math import sqrt
from math import log
from scipy.stats import gaussian_kde


# from tensorflow import keras as K

dim_red = 1  # perform PCA on the codes and plot the first two components
plot_on = 1  # plot the results, otherwise only textual output is returned
interp_on = 0  # interpolate data (needed if the input time series have different length)
tied_weights = 0  # train an AE where the decoder weights are the econder weights transposed
lin_dec = 1  # train an AE with linear activations in the decoder

# parse input data
parser = argparse.ArgumentParser()
parser.add_argument("--code_size", default=20, help="size of the code", type=int)
parser.add_argument("--w_reg", default=0.001, help="weight of the regularization in the loss function", type=float)
parser.add_argument("--a_reg", default=0.2, help="weight of the kernel alignment", type=float)
parser.add_argument("--num_epochs", default=5000, help="number of epochs in training", type=int)
parser.add_argument("--batch_size", default=25, help="number of samples in each batch", type=int)
parser.add_argument("--max_gradient_norm", default=1.0, help="max gradient norm for gradient clipping", type=float)
parser.add_argument("--learning_rate", default=0.001, help="Adam initial learning rate", type=float)
parser.add_argument("--hidden_size", default=30, help="size of the code", type=int)
args = parser.parse_args()
print(args)

# ================= DATASET =================
(train_data, train_labels, train_len, _, K_tr,
 valid_data, _, valid_len, _, K_vs,
 test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel='TCK',
                                                            inp='zero')  # data shape is [T, N, V] = [time_steps, num_elements, num_var]

# sort test data (for a better visualization of the inner product of the codes)

test_data = test_data_orig
train_data = train_data
valid_data = valid_data
test_data = test_data

print(
    '\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n'.format(train_data.shape, valid_data.shape, test_data.shape))

input_length = train_data.shape[1]  # same for all inputs

# ================= GRAPH =================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

encoder_inputs = train_data
prior_k = K_tr


# # ----- ENCODER -----

input_length = encoder_inputs.shape[1]

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.We1 = torch.nn.Parameter(torch.Tensor(input_length, args.hidden_size).uniform_(-1.0 / math.sqrt(input_length), 1.0 / math.sqrt(input_length)))
        self.We2 = torch.nn.Parameter(torch.Tensor(args.hidden_size, args.code_size).uniform_(-1.0 / math.sqrt(args.hidden_size), 1.0 / math.sqrt(args.hidden_size)))

        self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))

    def encode(self, train_data):
        hidden_1 = torch.tanh(torch.matmul(train_data.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
        return code

    def kernel_compute(self, code):
        code_K = torch.mm(code, torch.t(code))
        return  code_K


    def encoder_decoder(self,encoder_inputs):
        hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)

        # ----- DECODER -----
        if tied_weights:

            Wd1 = torch.transpose(We2)
            Wd2 = torch.transpose(We1)

        else:

            Wd1 = torch.nn.Parameter(
                torch.Tensor(args.code_size, args.hidden_size).uniform_(-1.0 / math.sqrt(args.code_size),
                                                                           1.0 / math.sqrt(args.code_size)))
            Wd2 = torch.nn.Parameter(
                torch.Tensor(args.hidden_size, input_length).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                                             1.0 / math.sqrt(args.hidden_size)))

            bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
            bd2 = torch.nn.Parameter(torch.zeros([input_length]))

            if lin_dec:
                hidden_2 = torch.matmul(code, Wd1) + bd1
            else:
                hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)

            dec_out = torch.matmul(hidden_2, Wd2) + bd2

            return  dec_out

    def loss(self,code, prior_K):
        
        # kernel on codes
        code_K = torch.mm(code,torch.t(code)) # dot product of tensors

        # ----- LOSS -----
        # kernel alignment loss with normalized Frobenius norm
        code_K_norm = code_K / torch.linalg.matrix_norm(code_K, ord='fro', dim=(- 2, - 1))
        prior_K_norm = prior_K / torch.linalg.matrix_norm(prior_K, ord='fro', dim=(- 2, - 1))
        k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm,prior_K_norm), ord='fro', dim=(- 2, - 1))

        return k_loss

# Initialize model
model = Model()

# L2 loss
reg_loss = 0

parameters = torch.nn.utils.parameters_to_vector(model.parameters())
#print ("PARAMS:", (parameters))
for tf_var in parameters:
    reg_loss += torch.mean(torch.linalg.norm(tf_var))

#tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss

# trainable parameters count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total parameters: {}'.format(total_params))

#Optimizer
optimizer = torch.optim.Adam(model.parameters(),args.learning_rate)

# ============= TENSORBOARD =============
writer = SummaryWriter()

# ================= TRAINING =================

# initialize training variables
time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = "logs/dkae_models/m_0.ckpt"

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))


try:
    for ep in range(args.num_epochs):

        # shuffle training data
        idx = np.random.permutation(train_data.shape[0])
        train_data_s = train_data[idx, :]
        K_tr_s = K_tr[idx, :][:, idx]


        for batch in range(max_batches):
            fdtr = {}
            fdtr["encoder_inputs"] = train_data_s[(batch) * batch_size:(batch + 1) * batch_size, :]
            fdtr["prior_K"] =  K_tr_s[(batch) * batch_size:(batch + 1) * batch_size,
                             (batch) * batch_size:(batch + 1) * batch_size]

            encoder_inputs = (fdtr["encoder_inputs"].astype(float))
            encoder_inputs = torch.from_numpy(encoder_inputs)
            #print("TYPE ENCODER_INP IN TRAIN:", type(encoder_inputs))

            prior_K = (fdtr["prior_K"].astype(float))
            prior_K = torch.from_numpy(prior_K)

            dec_out = model.encoder_decoder(encoder_inputs)
            #print("DEC OUT TRAIN:", dec_out)

            reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
            reconstruct_loss = reconstruct_loss.float()
            #print("RECONS LOSS TRAIN:", reconstruct_loss)

            k_loss = model.loss(dec_out,prior_K)
            k_loss = k_loss.float()
            #print ("K_LOSS:", k_loss)

            tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
            tot_loss = tot_loss.float()

            # Backpropagation
            optimizer.zero_grad()
            tot_loss.backward(retain_graph=True)
            #tot_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()
            
            #train_loss = reconstruct_loss
            #train_kloss = k_loss

            loss_track.append(reconstruct_loss)
            kloss_track.append(k_loss)

        #check training progress on the validations set (in blood data valid=train)
        if ep % 100 == 0:
            print('Ep: {}'.format(ep))

            # fdvs = {"encoder_inputs": valid_data,
            #         "prior_K": K_vs}

            fdvs = {}
            fdvs["encoder_inputs"] = valid_data
            fdvs["prior_K"] = K_vs


            #dec_out_val, lossvs, klossvs, vs_code_K, summary = sess.run(
             #   [dec_out, reconstruct_loss, k_loss, code_K, merged_summary], fdvs)

            encoder_inp = (fdvs["encoder_inputs"].astype(float))
            encoder_inp = torch.from_numpy(encoder_inp)

            prior_K = (fdvs["prior_K"].astype(float))
            prior_K = torch.from_numpy(prior_K)


            dec_out_val = model.encoder_decoder(encoder_inp)
            #print ("DEC OUT VAL:", dec_out_val)

            reconstruct_loss_val = torch.mean((dec_out - encoder_inputs) ** 2)

            k_loss_val = model.loss(dec_out_val,prior_K)
            tot_loss = reconstruct_loss_val + args.w_reg * reg_loss + args.a_reg * k_loss_val
            tot_loss = tot_loss.float()

            writer.add_scalar("reconstruct_loss", reconstruct_loss_val, ep)
            writer.add_scalar("k_loss", k_loss_val, ep)
            #writer.add_scalar("tot_loss", tot_loss, ep)

            # dec_out_val = dec_out_val
            # reconstruct_loss_val = reconstruct_loss_val
            # k_loss_val = k_loss_val

            #code_K = torch.tensordot(dec_out_val, torch.transpose(dec_out_val), axes=1)
            code_K = torch.mm(dec_out_val, torch.t(dec_out_val))
            vs_code_K = code_K

            print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
            reconstruct_loss_val, k_loss_val, torch.mean(torch.stack(loss_track[-100:])), torch.mean(torch.stack(kloss_track[-100:]))))

            # Save model yielding best results on validation
            if reconstruct_loss_val < min_vs_loss:
                min_vs_loss = reconstruct_loss_val
                torch.save(model, model_dir)
                torch.save(model.state_dict(), 'logs/dkae_models/best-model-parameters.pt')

                #save_path = saver.save(sess, model_name)

except KeyboardInterrupt:
    print('training interrupted')

time_tr_end = time.time()
print('Tot training time: {}'.format((time_tr_end - time_tr_start) // 60))

# ================= TEST =================
print('************ TEST ************ \n>>restoring from:' + model_dir + '<<')

model_test = torch.load(model_dir)
#tr_code = sess.run(code, {encoder_inputs: train_data})

# encoder_inputs_tr = train_data
# encoder_inputs_tr = encoder_inputs_tr.astype(float)
# encoder_inputs_tr = torch.from_numpy(encoder_inputs_tr)

train_data = torch.from_numpy(train_data)
tr_code = model.encode(train_data.float())

# encoder_inputs_te = test_data
# encoder_inputs_te = encoder_inputs_te.astype(float)
# encoder_inputs_te = torch.from_numpy(encoder_inputs_te)

test_data = torch.from_numpy(test_data)
ts_code = model.encode(test_data.float())
ts_code_K = model.kernel_compute(ts_code)

#dec_out = model.encoder_decoder(encoder_inputs_te)
pred = model.encoder_decoder(test_data.float())

#reconstruct_loss = torch.nn.MSELoss()(encoder_inputs_te, dec_out)
reconstruct_loss_te = torch.mean((pred - test_data.float()) ** 2)

#pred = dec_out
#pred_loss = reconstruct_loss_te

pred = pred.detach().numpy()

print("Pred shape:", pred.shape)
print('Test loss: %.3f' % (np.mean((pred - test_data) ** 2)))
print("TS code shape:", ts_code.shape)
print("TS code K shape:", ts_code_K.shape)

# reverse transformations
print("Test data shape:", test_data_orig.shape)
# pred = np.reshape(pred, (test_data_orig.shape[1], test_data_orig.shape[0], test_data_orig.shape[2]))
pred = np.reshape(pred, (test_data_orig.shape[1], test_data_orig.shape[0]))
#print("Pred shape after reshaping:", pred.shape)
# pred = np.transpose(pred,axes=[1,0,2])

pred = np.transpose(pred, axes=[1, 0])
print("Pred shape after transposing:", pred.shape)
test_data = test_data_orig
print("test data shape:", test_data.shape)

# MSE and corr
test_mse, test_corr = mse_and_corr(test_data, pred, test_len)
print('Test MSE: %.3f\nTest Pearson correlation: %.3f' % (test_mse, test_corr))

# kNN classification on the codes
# acc, f1, auc = classify_with_knn(tr_code, train_labels[:, 0], ts_code, test_labels[:, 0], k=1)
print("train labels shape:", train_labels.shape)

tr_code = tr_code.detach().numpy()
ts_code = ts_code.detach().numpy()

acc, f1, auc = classify_with_knn(tr_code, train_labels[:], ts_code, test_labels[:], k=1)
print('kNN -- acc: %.3f, F1: %.3f, AUC: %.3f' % (acc, f1, auc))

# dim reduction plots
if dim_red:
    dim_reduction_plot(ts_code, test_labels, 1)

writer.close()

Hi @ptrblck , I checked the code again and everything looks fine to me. I just did a line by line conversion of the TF code. The only two things I want you to kindly have a look are the initialization of the weights and bias parameters and the backpropagation algorithm during training.

The trained model parameters’ count is just half of what I get from TF code. The backprop doesn’t work without the retain_graph = True flag. Also, I have converted numpy to tensors and back to numpy at many places. Do you think that also result consuming more time during training? Thank you again!

hi @ptrblck , will you have some time to kindly take a look at my code. Thank you!

Hi, @Padmaksha_Roy

Thanks for your question here. Have you tried to use profiler to better help you understand where the bottleneck is for your code? (torch.profiler — PyTorch 1.11.0 documentation) Thanks!

hi @fduwjj , the code now works without the retain graph = True flag after I declared one of the variables inside the training loop instead of declaring it outside the loop. But the training time is still very high compared to the TF code. Any feedback to debug this will be very helpful!

hi @ptrblck , here is a sample code to reproduce the problem. The training time is still very high. Kindly let me know what is the problem here.

import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
import time
import matplotlib.pyplot as plt
import math
from scipy import stats
import scipy
import os
import datetime
from math import sqrt
from math import log
from torch import optim
from torch.autograd import Variable
from math import sqrt
from math import log



# from tensorflow import keras as K

# dim_red = 1  # perform PCA on the codes and plot the first two components
# plot_on = 1  # plot the results, otherwise only textual output is returned
# interp_on = 0  # interpolate data (needed if the input time series have different length)
# tied_weights = 0  # train an AE where the decoder weights are the econder weights transposed
# lin_dec = 1  # train an AE with linear activations in the decoder

# parse input data
parser = argparse.ArgumentParser()
parser.add_argument("--code_size", default=20, help="size of the code", type=int)
parser.add_argument("--w_reg", default=0.001, help="weight of the regularization in the loss function", type=float)
parser.add_argument("--a_reg", default=0.2, help="weight of the kernel alignment", type=float)
parser.add_argument("--num_epochs", default=5000, help="number of epochs in training", type=int)
parser.add_argument("--batch_size", default=25, help="number of samples in each batch", type=int)
parser.add_argument("--max_gradient_norm", default=1.0, help="max gradient norm for gradient clipping", type=float)
parser.add_argument("--learning_rate", default=0.001, help="Adam initial learning rate", type=float)
parser.add_argument("--hidden_size", default=30, help="size of the code", type=int)
args = parser.parse_args()
print(args)

# ================= DATASET =================
# (train_data, train_labels, train_len, _, K_tr,
#  valid_data, _, valid_len, _, K_vs,
#  test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel='TCK',
#                                                             inp='zero')  # data shape is [T, N, V] = [time_steps, num_elements, num_var]

train_data = np.random.rand(9000,6)
train_labels = np.ones([9000,1])
train_len = 9000

valid_data = np.random.rand(9000,6)
valid_len = 9000

test_data = np.random.rand(1500,6)
test_labels = np.ones([1500,1])

K_tr = np.random.rand(9000,9000)
K_ts = np.random.rand(1500,1500)
K_vs =  np.random.rand(9000,9000)

#test_data = test_data_orig


print(
    '\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n'.format(train_data.shape, valid_data.shape, test_data.shape))

input_length = train_data.shape[1]  # same for all inputs

# ================= GRAPH =================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

encoder_inputs = train_data
prior_k = K_tr

# ============= TENSORBOARD =============
writer = SummaryWriter()

# # ----- ENCODER -----

input_length = encoder_inputs.shape[1]
print ("INPUT ")

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.We1 = torch.nn.Parameter(torch.Tensor(input_length, args.hidden_size).uniform_(-1.0 / math.sqrt(input_length), 1.0 / math.sqrt(input_length)))
        self.We2 = torch.nn.Parameter(torch.Tensor(args.hidden_size, args.code_size).uniform_(-1.0 / math.sqrt(args.hidden_size), 1.0 / math.sqrt(args.hidden_size)))

        self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))


    def encoder(self, encoder_inputs):
        hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
        #print ("CODE ENCODER SHAPE:", code.size())
        return code

    def decoder(self,encoder_inputs):
        code = self.encoder(encoder_inputs)

        Wd1 = torch.nn.Parameter(
            torch.Tensor(args.code_size, args.hidden_size).uniform_(-1.0 / math.sqrt(args.code_size),
                                                                       1.0 / math.sqrt(args.code_size)))
        Wd2 = torch.nn.Parameter(
            torch.Tensor(args.hidden_size, input_length).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                                         1.0 / math.sqrt(args.hidden_size)))

        bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        bd2 = torch.nn.Parameter(torch.zeros([input_length]))


        #if lin_dec:
        #hidden_2 = torch.matmul(code, Wd1) + bd1
        #else:
        hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)

        #print("hidden SHAPE:", hidden_2.size())
        dec_out = torch.matmul(hidden_2, Wd2) + bd2

        return  dec_out

    def kernel_loss(self,code, prior_K):
        # kernel on codes
        code_K = torch.mm(code, torch.t(code))

        # ----- LOSS -----
        # kernel alignment loss with normalized Frobenius norm
        code_K_norm = code_K / torch.linalg.matrix_norm(code_K, ord='fro', dim=(- 2, - 1))
        prior_K_norm = prior_K / torch.linalg.matrix_norm(prior_K, ord='fro', dim=(- 2, - 1))
        k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm,prior_K_norm), ord='fro', dim=(- 2, - 1))
        return k_loss


# Initialize model
model = Model()

# trainable parameters count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total parameters: {}'.format(total_params))

#Optimizer
optimizer = torch.optim.Adam(model.parameters(),args.learning_rate)

# ================= TRAINING =================

# initialize training variables
time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = "logs/m_0.ckpt"

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

###############################################################################
# Training code
###############################################################################

try:
    for ep in range(args.num_epochs):

        # shuffle training data
        idx = np.random.permutation(train_data.shape[0])
        train_data_s = train_data[idx, :]
        K_tr_s = K_tr[idx, :][:, idx]


        for batch in range(max_batches):
            fdtr = {}
            fdtr["encoder_inputs"] = train_data_s[(batch) * batch_size:(batch + 1) * batch_size, :]
            fdtr["prior_K"] =  K_tr_s[(batch) * batch_size:(batch + 1) * batch_size,
                             (batch) * batch_size:(batch + 1) * batch_size]

            encoder_inputs = (fdtr["encoder_inputs"].astype(float))
            encoder_inputs = torch.from_numpy(encoder_inputs)
            #print("TYPE ENCODER_INP IN TRAIN:", type(encoder_inputs))

            prior_K = (fdtr["prior_K"].astype(float))
            prior_K = torch.from_numpy(prior_K)

            dec_out = model.decoder(encoder_inputs)

            #print("DEC OUT TRAIN:", dec_out)


            reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
            reconstruct_loss = reconstruct_loss.float()
            #print("RECONS LOSS TRAIN:", reconstruct_loss)

            enc_out = model.encoder(encoder_inputs)
            k_loss = model.kernel_loss(enc_out,prior_K)
            k_loss = k_loss.float()
            #print ("K_LOSS TRAIN:", k_loss)


            #print ("ENTRPY LOSS:", entrpy_loss)

            # Regularization L2 loss
            reg_loss = 0

            parameters = torch.nn.utils.parameters_to_vector(model.parameters())
            # print ("PARAMS:", (parameters))
            for tf_var in parameters:
                reg_loss += torch.mean(torch.linalg.norm(tf_var))

            tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
            tot_loss = tot_loss.float()

            # Backpropagation
            optimizer.zero_grad()
            #tot_loss.backward(retain_graph=True)
            tot_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()

            #tot_loss = tot_loss.detach()

            loss_track.append(reconstruct_loss)
            kloss_track.append(k_loss)

        #check training progress on the validations set (in blood data valid=train)
        if ep % 100 == 0:
            print('Ep: {}'.format(ep))

            # fdvs = {"encoder_inputs": valid_data,
            #         "prior_K": K_vs}

            fdvs = {}
            fdvs["encoder_inputs"] = valid_data
            fdvs["prior_K"] = K_vs


            #dec_out_val, lossvs, klossvs, vs_code_K, summary = sess.run(
             #   [dec_out, reconstruct_loss, k_loss, code_K, merged_summary], fdvs)

            encoder_inp = (fdvs["encoder_inputs"].astype(float))
            encoder_inp = torch.from_numpy(encoder_inp)

            prior_K_vs = (fdvs["prior_K"].astype(float))
            prior_K_vs = torch.from_numpy(prior_K_vs)

            enc_out_vs = model.encoder(encoder_inp)


            dec_out_val = model.decoder(encoder_inp)
            #print ("DEC OUT VAL:", dec_out_val)


            reconstruct_loss_val = torch.mean((dec_out_val - encoder_inp) ** 2)
            #print("RECONS LOSS VAL:", reconstruct_loss)

            k_loss_val = model.kernel_loss(enc_out_vs,prior_K_vs)
            #print("K_LOSS VAL:", k_loss_val)


            writer.add_scalar("reconstruct_loss", reconstruct_loss_val, ep)
            writer.add_scalar("k_loss", k_loss_val, ep)
            #writer.add_scalar("tot_loss", tot_loss, ep)


            print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
            reconstruct_loss_val, k_loss_val, torch.mean(torch.stack(loss_track[-100:])), torch.mean(torch.stack(kloss_track[-100:]))))
            #reconstruct_loss_val, k_loss_val, np.mean(loss_track[-100:].detach().numpy()), np.mean(kloss_track[-100:].detach().numpy())))


            # Save model yielding best results on validation
            if reconstruct_loss_val < min_vs_loss:
                min_vs_loss = reconstruct_loss_val
                torch.save(model, model_dir)
                #torch.save(model.state_dict(), 'logs/dkae_models/best-model-parameters.pt')

                #save_path = saver.save(sess, model_name)

except KeyboardInterrupt:
    print('training interrupted')

time_tr_end = time.time()
print('Tot training time: {}'.format((time_tr_end - time_tr_start) // 60))
writer.close()

Can be executed as:

python3 filename.py --code_size 4 --w_reg 0.001 --a_reg 0.1 --num_epochs 500 --max_gradient_norm 0.5 --learning_rate 0.001 --hidden_size 30 

Thank you!

Recreating new parameters in the forward pass (as done in decoder) wouldn’t make sense as they won’t be trained and their init might also create a performance penalty which could be avoided.
However, I would first recommend to make sure the models are actually the same as described in your cross-post as I’m unsure what the status is of this debugging effort.

Thanks! I will check.

Hi @ptrblck , I have done those changes as you said. But the code still takes almost 3-4 times the TF code for same number of epochs. Here is the sample code below.


import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy import stats
import scipy
import os
import datetime
from scipy.stats import gaussian_kde
from math import sqrt
from math import log
from torch import optim
from torch.autograd import Variable
from math import sqrt
from math import log
from sklearn.neighbors import KernelDensity

# from tensorflow import keras as K

# parse input data
parser = argparse.ArgumentParser()
parser.add_argument("--code_size", default=20, help="size of the code", type=int)
parser.add_argument("--w_reg", default=0.001, help="weight of the regularization in the loss function", type=float)
parser.add_argument("--a_reg", default=0.2, help="weight of the kernel alignment", type=float)
parser.add_argument("--num_epochs", default=5000, help="number of epochs in training", type=int)
parser.add_argument("--batch_size", default=25, help="number of samples in each batch", type=int)
parser.add_argument("--max_gradient_norm", default=1.0, help="max gradient norm for gradient clipping", type=float)
parser.add_argument("--learning_rate", default=0.001, help="Adam initial learning rate", type=float)
parser.add_argument("--hidden_size", default=30, help="size of the code", type=int)
args = parser.parse_args()
print(args)

# ================= DATASET =================
# (train_data, train_labels, train_len, _, K_tr,
#  valid_data, _, valid_len, _, K_vs,
#  test_data_orig, test_labels, test_len, _, K_ts) = getBlood(kernel='TCK',
#                                                             inp='zero')  # data shape is [T, N, V] = [time_steps, num_elements, num_var]

train_data = np.random.rand(9000,6)
train_labels = np.ones([9000,1])
train_len = 9000

valid_data = np.random.rand(9000,6)
valid_len = 9000

test_data = np.random.rand(1500,6)
test_labels = np.ones([1500,1])

K_tr = np.random.rand(9000,9000)
K_ts = np.random.rand(1500,1500)
K_vs =  np.random.rand(9000,9000)

#test_data = test_data_orig

print(
    '\n**** Processing Blood data: Tr{}, Vs{}, Ts{} ****\n'.format(train_data.shape, valid_data.shape, test_data.shape))

input_length = train_data.shape[1]  # same for all inputs

# ================= GRAPH =================

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

encoder_inputs = train_data
prior_k = K_tr

# ============= TENSORBOARD =============
writer = SummaryWriter()

# # ----- ENCODER -----

input_length = encoder_inputs.shape[1]
print("INPUT ")


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.We1 = torch.nn.Parameter(
            torch.Tensor(input_length, args.hidden_size).uniform_(-1.0 / math.sqrt(input_length),
                                                                  1.0 / math.sqrt(input_length)))
        self.We2 = torch.nn.Parameter(
            torch.Tensor(args.hidden_size, args.code_size).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                                    1.0 / math.sqrt(args.hidden_size)))

        self.be1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
        self.be2 = torch.nn.Parameter(torch.zeros([args.code_size]))

    def encoder(self, encoder_inputs):
        hidden_1 = torch.tanh(torch.matmul(encoder_inputs.float(), self.We1) + self.be1)
        code = torch.tanh(torch.matmul(hidden_1, self.We2) + self.be2)
        # print ("CODE ENCODER SHAPE:", code.size())
        return code


def decoder(encoder_inputs):
    Wd1 = torch.nn.Parameter(
        torch.Tensor(args.code_size, args.hidden_size).uniform_(-1.0 / math.sqrt(args.code_size),
                                                                1.0 / math.sqrt(args.code_size)))
    Wd2 = torch.nn.Parameter(
        torch.Tensor(args.hidden_size, input_length).uniform_(-1.0 / math.sqrt(args.hidden_size),
                                                              1.0 / math.sqrt(args.hidden_size)))

    bd1 = torch.nn.Parameter(torch.zeros([args.hidden_size]))
    bd2 = torch.nn.Parameter(torch.zeros([input_length]))

    hidden_2 = torch.tanh(torch.matmul(code, Wd1) + bd1)

    dec_out = torch.matmul(hidden_2, Wd2) + bd2

    return dec_out


def kernel_loss(code, prior_K):
    # kernel on codes
    code_K = torch.mm(code, torch.t(code))

    # ----- LOSS -----
    # kernel alignment loss with normalized Frobenius norm
    code_K_norm = code_K / torch.linalg.matrix_norm(code_K, ord='fro', dim=(- 2, - 1))
    prior_K_norm = prior_K / torch.linalg.matrix_norm(prior_K, ord='fro', dim=(- 2, - 1))
    k_loss = torch.linalg.matrix_norm(torch.sub(code_K_norm, prior_K_norm), ord='fro', dim=(- 2, - 1))
    return k_loss


# Initialize model
model = Model()

# trainable parameters count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total parameters: {}'.format(total_params))

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

# ================= TRAINING =================

# initialize training variables
time_tr_start = time.time()
batch_size = args.batch_size
max_batches = train_data.shape[0] // batch_size
loss_track = []
kloss_track = []
min_vs_loss = np.infty
model_dir = "logs/m_0.ckpt"

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

###############################################################################
# Training code
###############################################################################

try:
    for ep in range(args.num_epochs):

        # shuffle training data
        idx = np.random.permutation(train_data.shape[0])
        train_data_s = train_data[idx, :]
        K_tr_s = K_tr[idx, :][:, idx]

        for batch in range(max_batches):
            fdtr = {}
            fdtr["encoder_inputs"] = train_data_s[(batch) * batch_size:(batch + 1) * batch_size, :]
            fdtr["prior_K"] = K_tr_s[(batch) * batch_size:(batch + 1) * batch_size,
                              (batch) * batch_size:(batch + 1) * batch_size]

            encoder_inputs = (fdtr["encoder_inputs"].astype(float))
            encoder_inputs = torch.from_numpy(encoder_inputs)

            prior_K = (fdtr["prior_K"].astype(float))
            prior_K = torch.from_numpy(prior_K)

            code = model.encoder(encoder_inputs)
            dec_out = decoder(code)

            reconstruct_loss = torch.mean((dec_out - encoder_inputs) ** 2)
            reconstruct_loss = reconstruct_loss.float()
            # print("RECONS LOSS TRAIN:", reconstruct_loss)

            k_loss = kernel_loss(code, prior_K)
            k_loss = k_loss.float()

            # Regularization L2 loss
            reg_loss = 0

            parameters = torch.nn.utils.parameters_to_vector(model.parameters())
            # print ("PARAMS:", (parameters))
            for tf_var in parameters:
                reg_loss += torch.mean(torch.linalg.norm(tf_var))

            tot_loss = reconstruct_loss + args.w_reg * reg_loss + args.a_reg * k_loss
            tot_loss = tot_loss.float()

            # Backpropagation
            optimizer.zero_grad()
            # tot_loss.backward(retain_graph=True)
            tot_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()


            loss_track.append(reconstruct_loss)
            kloss_track.append(k_loss)

            # # check training progress on the validations set (in blood data valid=train)
            if ep % 100 == 0:
                print('Ep: {}'.format(ep))
                fdvs = {}
                fdvs["encoder_inputs"] = valid_data
                fdvs["prior_K"] = K_vs

                encoder_inp = (fdvs["encoder_inputs"].astype(float))
                encoder_inp = torch.from_numpy(encoder_inp)

                prior_K_vs = (fdvs["prior_K"].astype(float))
                prior_K_vs = torch.from_numpy(prior_K_vs)

                code_vs = model.encoder(encoder_inp)

                dec_out_val = decoder(code_vs)
                print("DEC OUT VAL:", dec_out_val.size())

                reconstruct_loss_val = torch.mean((dec_out_val - encoder_inp) ** 2)
                # print("RECONS LOSS VAL:", reconstruct_loss)

                k_loss_val = kernel_loss(code_vs, prior_K_vs)

            #
                writer.add_scalar("reconstruct_loss", reconstruct_loss_val, ep)
                writer.add_scalar("k_loss", k_loss_val, ep)
                # writer.add_scalar("tot_loss", tot_loss, ep)

            print('VS r_loss=%.3f, k_loss=%.3f -- TR r_loss=%.3f, k_loss=%.3f' % (
                reconstruct_loss, k_loss, torch.mean(torch.stack(loss_track[-100:])),
                torch.mean(torch.stack(kloss_track[-100:]))))

            # Save model yielding best results on validation
            if reconstruct_loss_val < min_vs_loss:
                min_vs_loss = reconstruct_loss_val
                torch.save(model, model_dir)


except KeyboardInterrupt:
    print('training interrupted')

time_tr_end = time.time()
print('Tot training time: {}'.format((time_tr_end - time_tr_start) // 60))

writer.close()
!python3 filename.py --code_size 4 --w_reg 0.001 --a_reg 0.1 --num_epochs 500 --max_gradient_norm 0.5 --learning_rate 0.001 --hidden_size 30 

Thanks!