VAE with Gumbel-Softmax

I implemented a Gumbel-softmax-based variational autoencoder following the TensorFlow implementation here (link not preserved). The code appears to work; however, convergence is much slower than with TensorFlow, even though both use the same optimizer (Adam) and learning rate. For instance, TensorFlow has already converged after 5000 iterations, whereas my implementation converges much more slowly. The initial value of the loss is almost identical across both implementations, which suggests that my implementation is broadly correct. I also calculate the binary cross-entropy for the decoder explicitly, to verify that the Bernoulli implementation in the distributions library is correct. The code is attached below, along with the logged loss values for both implementations.

I disabled learning-rate adjustments and temperature annealing in both the TF implementation and mine to keep things simple, so the same fixed temperature was used in both. I also verified that the MNIST data used in both implementations is un-normalized floating-point data in the range [0, 1].

Any thoughts or suggestions are welcome!


from __future__ import print_function

import argparse
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Command-line configuration.
parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every RNG the script uses (torch for sampling/init, numpy for the
# random one-hot codes drawn in train) so runs are reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# ToTensor() yields float pixels in [0, 1] with no further normalisation.
kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)

# Discrete latent space: N independent categorical variables with K classes each.
K = 10  # number of classes
N = 30  # number of categorical distributions

# Gumbel-softmax temperature. Annealing is disabled, so the value read by the
# model stays fixed at tau0 for the whole run.
tau0 = 1.0
tau = torch.tensor(tau0)
def sample_gumbel(shape, eps=1e-20, device=None):
    """Draw i.i.d. samples from the standard Gumbel(0, 1) distribution.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.

    Args:
        shape: Output shape (a ``torch.Size`` or a tuple of ints).
        eps: Small constant added inside both logarithms so that ``U == 0``
            does not produce ``log(0)``.
        device: Device to sample on. ``None`` samples on the GPU when CUDA is
            available (as before) and falls back to the CPU otherwise; the
            previous hard-coded ``.cuda()`` crashed on CPU-only machines.

    Returns:
        A float tensor of the requested shape.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (continuous) sample from the Gumbel-softmax distribution.

    Args:
        logits: Unnormalised log-probabilities; the last dimension indexes
            the categories.
        temperature: Softmax temperature; the perturbed logits are divided by
            it before the softmax, so smaller values give sharper samples.

    Returns:
        A tensor shaped like ``logits`` whose last dimension sums to 1.
    """
    gumbel_noise = sample_gumbel(logits.size())
    scaled = (logits + gumbel_noise) / temperature
    return F.softmax(scaled, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-softmax distribution, optionally discretised.

    Args:
        logits: Tensor of shape ``[*, n_class]`` holding unnormalised
            log-probabilities.
        temperature: Softmax temperature passed to ``gumbel_softmax_sample``.
        hard: If True, the forward value is the one-hot vector of the argmax
            while gradients flow through the relaxed sample
            (straight-through Gumbel-softmax).

    Returns:
        Tensor of shape ``[*, n_class]``: a relaxed sample whose last
        dimension sums to 1 when ``hard`` is False, or a one-hot vector when
        ``hard`` is True.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return y
    shape = y.size()
    _, ind = y.max(dim=-1)
    # Build the one-hot argmax on a 2-D view, then restore the original shape.
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Value equals y_hard; the gradient is that of y (detach hides the difference).
    return (y_hard - y).detach() + y

class VAE(nn.Module):
    """Variational autoencoder with a categorical (Gumbel-softmax) latent.

    The encoder maps a flattened 28x28 image (784 values) to ``N`` groups of
    ``K`` logits; the decoder maps an ``N * K`` latent code back to 784
    Bernoulli logits. The module-level ``K``, ``N`` and ``tau`` configure it.
    """

    def __init__(self):
        super(VAE, self).__init__()

        # Encoder: 784 -> 512 -> 256 -> K*N logits.
        self.fce1 = nn.Linear(784, 512)
        self.fce2 = nn.Linear(512, 256)
        self.fce3 = nn.Linear(256, K*N)
        # Decoder: K*N -> 256 -> 512 -> 784 logits.
        self.fcd1 = nn.Linear(K*N, 256)
        self.fcd2 = nn.Linear(256, 512)
        self.fcd3 = nn.Linear(512, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(1)

    def encode(self, x):
        """Encode flattened images into categorical posterior parameters.

        Args:
            x: Tensor of shape ``(batch, 784)``.

        Returns:
            ``(logits_y, log_qy, qy)``, each of shape ``(batch * N, K)``: the
            raw logits, log-probabilities and probabilities of every
            categorical variable.
        """
        he1 = self.relu(self.fce1(x))
        he2 = self.relu(self.fce2(he1))
        logits_y = self.fce3(he2).view(-1, K)
        qy = self.softmax(logits_y)
        # log_softmax stays exact where log(qy + 1e-20) is clamped at about
        # -46 once a probability underflows, and its gradients are better.
        log_qy = F.log_softmax(logits_y, dim=1)
        return logits_y, log_qy, qy

    def reparameterize(self, mu, logvar):
        """Gaussian reparameterisation trick.

        NOTE(review): leftover from the Gaussian VAE example; nothing in this
        file calls it. Kept so the class interface is unchanged.

        Returns ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` in
        training mode, and ``mu`` unchanged in evaluation mode.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Draw a Gumbel-softmax sample from logits ``z`` and decode it.

        Args:
            z: Logits whose last dimension is ``K`` and whose total size is a
                multiple of ``N * K`` (e.g. ``(batch * N, K)`` from ``encode``).

        Returns:
            Bernoulli logits of shape ``(batch, 784)``.

        NOTE(review): ``z`` is always passed through ``gumbel_softmax`` with
        the module-level ``tau``, so decoding a one-hot code perturbs it with
        Gumbel noise instead of decoding it verbatim.
        """
        # Set hard=True for straight-through Gumbel-softmax.
        y = gumbel_softmax(z, tau, hard=False)
        hd1 = self.relu(self.fcd1(y.view(-1, N * K)))
        hd2 = self.relu(self.fcd2(hd1))
        return self.fcd3(hd2)

    def forward(self, x):
        """Encode ``x`` and decode a Gumbel-softmax sample of its latent code.

        Returns:
            ``(recon_logits, log_qy, qy)``, where ``recon_logits`` has shape
            ``(batch, 784)``.
        """
        logits_y, log_qy, qy = self.encode(x.view(-1, 784))
        return self.decode(logits_y), log_qy, qy

model = VAE()
# Move the parameters to the GPU before building the optimizer, as the
# torch.optim documentation recommends.
if args.cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_function(recon_x, log_qy, qy, data):
    """Negative ELBO for the Gumbel-softmax VAE, averaged over the batch.

    Args:
        recon_x: Decoder Bernoulli logits, shape ``(batch, D)``.
        log_qy: Posterior log-probabilities, shape ``(batch * N, K)``.
        qy: Posterior probabilities, same shape as ``log_qy``.
        data: Target pixels in ``[0, 1]``, with ``batch`` as the leading
            dimension and ``D`` elements per example (e.g. ``(batch, 1, 28, 28)``).

    Returns:
        ``(loss, bce, kl)`` as scalar tensors:
        - ``loss``: the mean negative ELBO.
        - ``bce``: the mean Bernoulli log-likelihood, which is non-positive
          despite the name.
        - ``kl``: the mean KL divergence to a uniform categorical prior.
    """
    batch = data.size(0)
    num_classes = qy.size(-1)
    # KL(q || Uniform(K)) = sum q * (log q - log(1/K)) = sum q * (log q + log K),
    # summed over all N categorical variables of each example.
    kl_terms = qy * (log_qy + math.log(num_classes))
    KL = kl_terms.view(batch, -1).sum(1)
    data_ = data.view(batch, -1)
    # Explicit Bernoulli log-likelihood. log(sigmoid(x)) and
    # log(1 - sigmoid(x)) == log(sigmoid(-x)) go through logsigmoid: the naive
    # form gives -inf (and NaN gradients) once sigmoid saturates in float32.
    bce = torch.sum(data_ * F.logsigmoid(recon_x)
                    + (1 - data_) * F.logsigmoid(-recon_x), 1)
    elbo = bce - KL
    return torch.mean(-elbo), torch.mean(bce), torch.mean(KL)


def _plot_samples(model):
    """Decode 100 random one-hot latent codes and draw them as a 10x10 grid."""
    device = next(model.parameters()).device
    M = 100 * N
    np_y = np.zeros((M, K))
    # One random class per categorical variable -> 100 codes of N one-hots.
    np_y[range(M), np.random.choice(K, M)] = 1
    np_y = np.reshape(np_y, [100, N, K])

    with torch.no_grad():
        # float32 to match the Linear weights (np.zeros produces float64).
        z = torch.tensor(np_y, dtype=torch.float32, device=device)
        # NOTE(review): model.decode runs z through gumbel_softmax, so these
        # one-hot codes are perturbed with Gumbel noise before decoding.
        px = model.decode(z)
        recon_x = torch.sigmoid(px).cpu().numpy()
    np_x = recon_x.reshape((10, 10, 28, 28))
    # split into 10 (1,10,28,28) images, concat along columns -> 1,10,28,280
    np_x = np.concatenate(np.split(np_x, 10, axis=0), axis=3)
    # split into 10 (1,1,28,280) images, concat along rows -> 1,1,280,280
    np_x = np.concatenate(np.split(np_x, 10, axis=1), axis=2)
    x_img = np.squeeze(np_x)
    plt.imshow(x_img, cmap=plt.cm.gray, interpolation='none')


def train(epoch):
    """Run one training epoch, log progress, then plot decoded samples.

    Args:
        epoch: 1-based epoch number, used only in log messages.
    """
    model.train()
    device = next(model.parameters()).device
    train_loss = 0.0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        px, log_qy, qy = model(data)
        loss, bce, KL = loss_function(px, log_qy, qy, data)
        loss.backward()
        optimizer.step()
        # loss is a per-example mean; weight it by the batch size so the epoch
        # average printed below is a true per-example average.
        train_loss += loss.item() * len(data)
        # Temperature annealing is disabled for the TF comparison. To re-enable
        # it, rebind the module-level tau here; that needs a `global tau`
        # declaration, otherwise VAE.decode never sees the new value.

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \tBCE: {:.6f} \tKL: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(), bce.item(), KL.item()))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

    _plot_samples(model)

# The run length comes from --epochs (default 10); pass --epochs 1 for a
# quick run instead of overriding args in code.
for epoch in range(1, args.epochs + 1):
    train(epoch)

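Convergence (my code; the bracketed count is the number of training images seen, logged every 10 batches of 100, so 5000 images corresponds to 50 iterations):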
|Train Epoch: 1 [0/60000 (0%)]|Loss: 542.632080 |BCE: -542.598450 |KL: 0.033671|
|Train Epoch: 1 [1000/60000 (2%)]|Loss: 273.803497 |BCE: -264.880951 |KL: 8.922541|
|Train Epoch: 1 [2000/60000 (3%)]|Loss: 213.316895 |BCE: -213.277115 |KL: 0.039777|
|Train Epoch: 1 [3000/60000 (5%)]|Loss: 215.961517 |BCE: -215.944489 |KL: 0.017020|
|Train Epoch: 1 [4000/60000 (7%)]|Loss: 208.288528 |BCE: -208.240768 |KL: 0.047799|
|Train Epoch: 1 [5000/60000 (8%)]|Loss: 203.074173 |BCE: -203.020462 |KL: 0.053740|

Convergence (TensorFlow; logged every 1000 training steps):
Step 1, ELBO: -544.939, KL: 0.540, BCE: -544.399
Step 1001, ELBO: -136.944, KL: 11.626, BCE: -125.318
Step 2001, ELBO: -121.850, KL: 13.688, BCE: -108.161
Step 3001, ELBO: -108.513, KL: 14.693, BCE: -93.820
Step 4001, ELBO: -110.518, KL: 15.602, BCE: -94.916

Could you paste the code with its formatting intact? It is a headache for me to rearrange it.
Have a look at this implementation.