Reconstructing vectors with different ranges using a VAE

I have a question about my variational autoencoder model. I'm trying to reconstruct 1D vectors whose first 9 entries (indices 0-8) are numbers greater than or equal to zero and whose entries 9 to 18 are numbers between 0 and 1. What I've been doing is applying nn.ReLU() and nn.Softmax() to the corresponding entries in the forward() function, and I do the same in the sample() method so that instances sampled after training respect the specified ranges. (I've noticed that without this, the plain decoding process ends up violating the ranges, since I only have a generic nn.Linear() at the end of the decoder.)

Do you think this procedure of applying different activation functions to different parts of the reconstructed vector in the forward() and sample() methods is correct?

Should I instead try to incorporate everything into the decoder (see the sketch right after this question)?
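
As a sketch of that option, the range constraints could be wrapped in a small nn.Module and appended to the decoder's nn.Sequential, so that forward() and sample() both pass through the same constrained output. OutputActivation is a hypothetical name and the layer sizes below are placeholders, not the real architecture:

import torch
import torch.nn as nn

class OutputActivation(nn.Module):
    # Enforces the per-slice ranges on the decoder output:
    # ReLU on entries 0-8, softmax over entries 9-17.
    def forward(self, x):
        part1 = torch.relu(x[:, :9])
        part2 = torch.softmax(x[:, 9:18], dim=1)
        return torch.cat([part1, part2], dim=1)

decoder = nn.Sequential(
    nn.Linear(20, 128),   # placeholder input width (latent_size + target_size)
    nn.ReLU(),
    nn.Linear(128, 18),
    OutputActivation(),   # decode() then returns in-range values directly
)

With this in place the activation logic lives in one spot, and sample() no longer has to repeat it.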

Also, given that my input dimension is pretty low, does it still make sense to use hidden layers of 256, 512, ... neurons? (A narrower alternative is sketched below for comparison.)
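
For comparison on the capacity question, a narrower stack for an 18-dimensional input could look like the following; the widths and latent size here are arbitrary assumptions, not a recommendation:

import torch.nn as nn

# Hypothetical compact encoder for an 18-dimensional input with a 1-dimensional condition
encoder = nn.Sequential(
    nn.Linear(18 + 1, 64),   # in_size + target_size, assumed 18 + 1
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 2 * 8),    # 2 * latent_size for mu and logvar, latent_size = 8 assumed
)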

import torch
import torch.nn as nn

latent_size = 16  # assumed value; this global is referenced below but was not defined in the original snippet

class CVAE(BaseModel):

    def __init__(self, in_size, target_size):

        super().__init__()
        
        # the input is concatenated with the target conditioning property
        self.encoder = nn.Sequential(
            nn.Linear(in_size + target_size, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, latent_size * 2),  # outputs mu and logvar
        )
 
        self.decoder = nn.Sequential(
            nn.Linear(latent_size + target_size, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, in_size),  # unconstrained; ranges enforced in forward()/sample()
        )

    def reparameterise(self, mu, logvar):
        # Reparameterisation trick: z = mu + sigma * eps while training,
        # the posterior mean at evaluation time.
        if self.training:
            std = logvar.mul(0.5).exp()
            eps = torch.randn_like(std)
            return eps * std + mu
        else:
            return mu

    def encode(self, x, cond):
        # Concatenate the input with the conditioning target, then split the
        # encoder output into mu and logvar.
        x = torch.cat([x, cond], dim=1)
        mu_logvar = self.encoder(x).view(-1, 2, latent_size)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, cond):
        mu, logvar = self.encode(x, cond)
        z = self.reparameterise(mu, logvar)
        z = torch.cat([z, cond], dim=1)
        x_hat = self.decode(z)

        # Enforce the output ranges: non-negative first 9 entries,
        # softmax over entries 9-18.
        part1 = torch.relu(x_hat[:, :9])
        part2 = torch.softmax(x_hat[:, 9:18], dim=1)
        x_hat = torch.cat([part1, part2], dim=1)

        return x_hat, mu, logvar

    @torch.no_grad()
    def sample(self, n_samples, cond):
        # for bandgap e.g. cond = (6 - 3) * torch.rand(n_samples, 1) + 3
        device = next(self.parameters()).device
        z = torch.randn((n_samples, latent_size), device=device)
        z = torch.cat([z, cond.to(device)], dim=1)
        x = self.decode(z)

        # Apply the same range constraints as in forward().
        part1 = torch.relu(x[:, :9])
        part2 = torch.softmax(x[:, 9:18], dim=1)
        x = torch.cat([part1, part2], dim=1)

        return x.cpu().numpy()
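
For completeness, a minimal usage sketch, assuming BaseModel is an nn.Module subclass and using in_size = 18, target_size = 1 as example values:

import torch

model = CVAE(in_size=18, target_size=1)
model.eval()

# Conditioning values, e.g. bandgaps drawn uniformly from [3, 6)
cond = (6 - 3) * torch.rand(5, 1) + 3
samples = model.sample(n_samples=5, cond=cond)
print(samples.shape)  # (5, 18)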