RuntimeError: F.binary_cross_entropy all elements of input should be between 0 and 1

‘’'class Dir_VAE(nn.Module):
def init(self):
super(Dir_VAE, self).init()
self.encoder = nn.Sequential(

        nn.Conv2d(nc, ndf, 4,4, 0,bias=False),
        nn.LeakyReLU(0.2, inplace=True),
       
        nn.Conv2d(ndf, ndf * 2, 4,4, 0,bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
       
        nn.Conv2d(ndf * 2, ndf * 4, 4, 4, 0,bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
       
        nn.Conv2d(ndf * 4, 512, 4, 2,0,bias=False),
        nn.LeakyReLU(0.2, inplace=True),
      
    )

    self.decoder = nn.Sequential(
       
        nn.ConvTranspose2d(512, ngf * 4, 4, 2, 0, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.ReLU(True),
        
        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 4, 0, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
       
        nn.ConvTranspose2d(ngf * 2, ngf * 2, 4, 4, 0, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
        
        nn.ConvTranspose2d(ngf * 2, nc, 4, 4, 0, bias=False),
       
        nn.Sigmoid()
       
    )
    self.fc1 = nn.Linear(512, 256)
    self.fc21 = nn.Linear(256, 10)
    self.fc22 = nn.Linear(256, 10)

    self.fc3 = nn.Linear(10, 256)
    self.fc4 = nn.Linear(256, 512)

    self.lrelu = nn.LeakyReLU()
    self.relu = nn.ReLU()

    # Dir prior
    self.prior_mean, self.prior_var = map(nn.Parameter, prior(10, 0.3)) # 0.3 is a hyper param of Dirichlet distribution
    self.prior_logvar = nn.Parameter(self.prior_var.log())
    self.prior_mean.requires_grad = False
    self.prior_var.requires_grad = False
    self.prior_logvar.requires_grad = False


def encode(self, x):
    conv = self.encoder(x);
    print('Size', conv.shape)
    h1 = self.fc1(conv.view(-1, 512))
    return self.fc21(h1), self.fc22(h1)

def decode(self, gauss_z):
    dir_z = F.softmax(gauss_z,dim=1) 
    h3 = self.relu(self.fc3(dir_z))
    deconv_input = self.fc4(h3)
    print('Deconv ', deconv_input.shape)
    deconv_input = deconv_input.view(-1,512,1,1)
    return self.decoder(deconv_input)

def reparameterize(self, mu, logvar):
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std


def forward(self, x):
    mu, logvar = self.encode(x)
    gauss_z = self.reparameterize(mu, logvar) 
    # gause_z is a variable that follows a multivariate normal distribution
    # Inputting gause_z into softmax func yields a random variable that follows a Dirichlet distribution (Softmax func are used in decoder)
    dir_z = F.softmax(gauss_z,dim=1) # This variable follows a Dirichlet distribution
    return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(self, recon_x, x, mu, logvar, K):
    beta =0.9
    print('Recon ',recon_x.shape)
    print('Data ' ,x.shape)              
    BCE = F.binary_cross_entropy(recon_x.view(-1, 65536), x.view(-1, 65536), reduction='sum')
   
    
    KLD = -0.5 * torch.sum(1+logvar - mu**2 - torch.exp(logvar), axis=1)
    return BCE + beta*KLD```
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-ef9ef8984921> in <module>
    267     # 学習(Train)
    268     for epoch in range(50):
--> 269         train(epoch)
    270         test(epoch)
    271         with torch.no_grad():

<ipython-input-8-ef9ef8984921> in train(epoch)
    226         optimizer.zero_grad()
    227         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
--> 228         loss = model.loss_function(recon_batch, data, mu, logvar, 10)
    229         loss = loss.mean()
    230         loss.backward()

<ipython-input-8-ef9ef8984921> in loss_function(self, recon_x, x, mu, logvar, K)
    198         print('Recon ',recon_x.shape)
    199         print('Data ' ,x.shape)
--> 200         BCE = F.binary_cross_entropy(recon_x.view(-1, 65536), x.view(-1, 65536), reduction='sum')
    201         # ディリクレ事前分布と変分事後分布とのKLを計算
    202         # Calculating KL with Dirichlet prior and variational posterior distributions

C:\Conda5\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2760         weight = weight.expand(new_size)
   2761 
-> 2762     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   2763 
   2764 

RuntimeError: all elements of input should be between 0 and 1```'''

Hi Hillary!

Look for nans (or infs) in recon_batch. It’s not entirely clear how
your code hooks together, and you don’t tell us what model is, but
the error message is complaining about input, the first argument to
binary_cross_entropy(), and nans would be the most likely cause.

Given what I think your code is doing, you shouldn’t be getting values
less than zero or greater than one in recon_batch, but you should
check for values outside of [0.0, 1.0], as well.

As an aside, for reasons of numerical stability, you should get rid
of the final Sigmoid “layer” in self.decoder (so that it outputs
logits rather than probabilities) and use
torch.nn.functional.binary_cross_entropy_with_logits()
(instead of binary_cross_entropy()) for that piece of your
Kulback-Leibler-divergence loss_function().

Best.

K. Frank