Exploding loss and gradients when training a VAE on MNIST

I am trying to apply the Variational Auto-Encoder (see https://arxiv.org/abs/1312.6114) to the MNIST data set. I follow the paper exactly in terms of the neural network architecture and all other parameters and formulas. Unfortunately, after one step the loss (the negative ELBO here) and the gradients explode. Changing the learning rate, the mini-batch size and the network structure did not fix the problem. What else can I do to avoid this behaviour? Could you help me? I really appreciate any support.
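
For reference, the estimator I am trying to implement is the one with the analytic KL term from Section 3 and Appendix B of the paper (Gaussian encoder, Bernoulli decoder). If I read it correctly, the per-datapoint lower bound is estimated as

$$\mathcal{L}(\theta,\phi;x^{(i)}) \simeq \frac{1}{2}\sum_{j=1}^{J}\Bigl(1+\log\bigl(\sigma_j^{(i)}\bigr)^2-\bigl(\mu_j^{(i)}\bigr)^2-\bigl(\sigma_j^{(i)}\bigr)^2\Bigr)+\frac{1}{L}\sum_{l=1}^{L}\log p_\theta\bigl(x^{(i)}\mid z^{(i,l)}\bigr)$$

with $z^{(i,l)}=\mu^{(i)}+\sigma^{(i)}\odot\epsilon^{(l)}$ and $\epsilon^{(l)}\sim\mathcal{N}(0,I)$, and the mini-batch estimate is scaled by $N/M$. The loss I minimize is the negative of this estimate.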

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import *
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import sys

#Load the data:
dl = DataLoader(torchvision.datasets.MNIST('/data/mnist', train=True, download=True))
tensor = dl.dataset.data
tensor = tensor.to(dtype=torch.float32)
tr = tensor.reshape(tensor.size(0), -1) 
tr = tr/255
targets = dl.dataset.targets
targets = targets.to(dtype=torch.long)

#This will be the data which I use:
Traindata = tr[0:50000]
Labels = targets[0:50000]
#For binary VAEs, we use only binary data:
BTraindata= torch.round(Traindata)
#Define the global variables:
minibatch_size=100
latent_dim=2
dim_inp_dec=latent_dim
dim_hidden_dec=100
dim_out_dec=28*28
dim_inp_enc=28*28
dim_hidden_enc=100
train_size=Traindata.size()[0]
#Dataloader:
trainloader=torch.utils.data.DataLoader(Traindata, batch_size=minibatch_size, shuffle=True)
#This is the Variational Auto Encoder class in the case of binary data:
class BVAE(nn.Module):
    def __init__(self,latent_dim):
        super(BVAE, self).__init__()
        #Define the decoder structural parts:
        self.dec1=nn.Linear(dim_inp_dec,dim_hidden_dec,bias=True)
        self.dec2=nn.Linear(dim_hidden_dec,dim_out_dec,bias=True)
        self.decactiv1=nn.Tanh()
        self.decactiv2=nn.Sigmoid()
        
        #Define the encoder structural parts:
        self.enc1=nn.Linear(dim_inp_enc,dim_hidden_enc,bias=True)
        self.enc21=nn.Linear(dim_hidden_enc,latent_dim,bias=True)
        self.enc22=nn.Linear(dim_hidden_enc,latent_dim,bias=True)
        self.encactiv=nn.Tanh()    
        
        self.latent_dim=latent_dim
    
    def encode(self,x):
        mean=self.enc21(self.encactiv(self.enc1(x)))
        log_var=self.enc22(self.encactiv(self.enc1(x)))
        return mean,log_var
    
    def decode(self,x):
        x_one=self.dec1(x)
        dec_hidden=self.decactiv1(x_one)
        output=self.decactiv2(self.dec2(dec_hidden))
        return output
    
    def forward(self,x,normsamples):
        mean,log_var=self.encode(x)
        var=torch.exp(log_var)
        sample=mean+torch.exp(log_var)*normsamples
        return(self.decode(sample))
   
    #Instead of sampling after the encoder network, we can also replace it by its mean (i.e. ignoring the variance):
    def detforward(self,x):
        mean=self.encode(x)[0]
        return(self.decode(mean))
    
    #The loss of a VAE is the negative ELBO term:
    def loss(self,mini_vector,latent_samples):
        L=latent_samples.size()[1]
        M=mini_vector.size()[0]
        if (M!=latent_samples.size()[0]): 
            sys.exit("Sample vector size does not fit mini_vector size")
        
        ELBO=torch.tensor(0.0,dtype=torch.float)
        for i in range(M):
            mean,log_var = self.encode(mini_vector[i])
            Entropy=0.5*torch.sum(torch.ones(latent_dim)+log_var-torch.pow(mean,2)-torch.exp(log_var))
            sum=torch.tensor(0.0,dtype=torch.float)
            for l in range(L):
                Bern_par=self.decode(mean+torch.sqrt(torch.exp(log_var))*latent_samples[i][l])
                sum+=torch.sum(mini_vector[i]*torch.log(Bern_par)+(1-mini_vector[i])*torch.log(1-Bern_par))
            #The second term of the Lower bound is:
            Recon_term=sum/L
            #Add the two terms:
            ELBO+=Entropy+Recon_term
        ELBO=(train_size/M)*ELBO
        return (-ELBO)
    
    #We now define the train function. nsamples is the number of samples per data point (nsamples_per_point):
    def train(self,data_loader,learning_rate=0.001,epochs=10,minibatch_size=100,nsamples=1):
        for i,data in enumerate(data_loader):
            samples=torch.randn(minibatch_size,nsamples,self.latent_dim)
            print("-------------Epoch ", i,"---------------")
            #Set gradient to zero:
            self.zero_grad()
            #Compute loss:
            loss = self.loss(data,samples)
            print("Epoch: ",i," Loss: ",loss.item())
            loss.backward()
            for p in self.parameters():
                with torch.no_grad():
                    print("Norm of gradient:", torch.sum(p.grad.data**2))
                    p.data-=learning_rate*p.grad.data
                    p.grad.zero_()



            
        
autoencoder=BVAE(latent_dim)
autoencoder.train(trainloader)

This is the output I get:
-------------Epoch 0 ---------------
Epoch: 0 Loss: 27568040.0
Norm of gradient: tensor(8.4638e+09)
Norm of gradient: tensor(6.8262e+10)
Norm of gradient: tensor(3.3651e+12)
Norm of gradient: tensor(3.2536e+11)
Norm of gradient: tensor(3.1535e+11)
Norm of gradient: tensor(8.2232e+09)
Norm of gradient: tensor(2.7858e+10)
Norm of gradient: tensor(1.4739e+10)
Norm of gradient: tensor(1.4373e+10)
Norm of gradient: tensor(8.9952e+09)
-------------Epoch 1 ---------------
Epoch: 1 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
-------------Epoch 2 ---------------
Epoch: 2 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
-------------Epoch 3 ---------------
Epoch: 3 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)

Could you check your loss implementation against the example below? Since the gradient norms are that high, I would assume the loss itself is what blows up.
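
For comparison, here is a minimal, vectorized sketch of the same negative-ELBO loss (the name elbo_loss and the arguments model, x, nsamples are mine; model is assumed to be your BVAE instance, so it exposes encode and decode). It differs from your version in two ways: the reconstruction term goes through F.binary_cross_entropy, which clamps its log outputs so the loss cannot become inf/nan when the sigmoid saturates at exactly 0 or 1, and the result is averaged over the mini-batch instead of being multiplied by train_size/M.

import torch
import torch.nn.functional as F

def elbo_loss(model, x, nsamples=1):
    # x: (M, 784) mini-batch of (binarized) images with values in [0, 1]
    mean, log_var = model.encode(x)
    std = torch.exp(0.5 * log_var)  # standard deviation, not variance
    # analytic KL(q(z|x) || N(0, I)): summed over latent dims, one value per data point
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1)
    recon = torch.zeros_like(kl)
    for _ in range(nsamples):
        eps = torch.randn_like(std)
        z = mean + std * eps  # reparameterization trick
        x_hat = model.decode(z)
        # binary_cross_entropy clamps its log terms, so no -inf from log(0)
        recon += F.binary_cross_entropy(x_hat, x, reduction='none').sum(dim=1)
    recon = recon / nsamples
    # average over the mini-batch; your code multiplies by train_size/M instead,
    # which scales the loss and the gradients up by a factor of ~500 at your settings
    return (recon + kl).mean()

Inside your training loop, elbo_loss(self, data) should slot in where you currently call self.loss(data, samples) (it draws its own noise). If you want to keep the N/M scaling from the paper, you will most likely need a much smaller learning rate (or per-parameter step sizes), since the gradients are scaled up by the same factor.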