ValueError: too many values to unpack (expected 2) - conditional prior VAE

I am trying to implement a conditional VAE on a toy dataset, but I get this error. Here is my code:

def gen_batch(BATCH_SIZE):
    theta = (np.pi/4) * torch.randint(0, 8, (BATCH_SIZE,)).float().to(device)
    centers = torch.stack((torch.cos(theta), torch.sin(theta)), dim = -1)
    noise = torch.randn_like(centers) * 0.1

    return centers + noise

def data_gen(BATCH_SIZE):
    #8 gaussians
    while True:
        yield gen_batch(BATCH_SIZE)

test_loader = train_loader = data_gen(args.batch_size)

def train(epoch):
    CVAE.train()
    train_loss = 0
   
    for batch_idx, (data,cond) in enumerate(train_loader):
        if batch_idx > 100:
            break #100 batches per epoch
        # data = data.to(device)
        data, cond = data.cuda(), one_hot(cond, cond_dim).cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = CVAE(data, cond)
        print("mu", mu)
        print("logvar", logvar)
        loss = loss_function(recon_batch, data, mu, logvar)
        recon_batch = recon_batch.detach().cpu().numpy()  # a CUDA tensor must be moved to the CPU before .numpy()

ValueError                                Traceback (most recent call last)
<ipython-input-93-9983f3e978b5> in <module>
      1 for epoch in range(1, 51):
----> 2     train(epoch)

<ipython-input-92-aa690337c5d3> in train(epoch)
      3     train_loss = 0
      4 
----> 5     for batch_idx, (data,cond) in enumerate(train_loader):
      6         if batch_idx > 100:
      7             break #100 batches per epoch

ValueError: too many values to unpack (expected 2)

I would be really grateful if anyone could help me.

Your loop unpacks each item from train_loader into two values, (data, cond), but data_gen() only yields what gen_batch() returns: a single tensor of shape (BATCH_SIZE, 2).

When Python unpacks that tensor into data and cond, it iterates over the tensor's first dimension, which has BATCH_SIZE rows rather than 2. For any batch size larger than 2 there are too many values to unpack, hence the error.
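A minimal reproduction, assuming only PyTorch and a made-up batch size of 4: unpacking a tensor into two names iterates over its first dimension, so a batch of shape (BATCH_SIZE, 2) only unpacks without error when BATCH_SIZE happens to be 2, and even then the two names receive two samples rather than data and a condition.

```python
import torch

BATCH_SIZE = 4
batch = torch.randn(BATCH_SIZE, 2)  # same shape as what gen_batch() returns

# Iterating over a tensor walks its first dimension, one row at a time.
rows = list(batch)
print(len(rows))  # BATCH_SIZE rows, each of shape (2,)

# Unpacking into two names needs exactly 2 items, but there are BATCH_SIZE.
error_message = None
try:
    data, cond = batch
except ValueError as err:
    error_message = str(err)
print(error_message)  # a "too many values to unpack" message

# With exactly 2 rows the unpacking "works", but data and cond are just two
# samples, not a batch of samples and its condition.
data, cond = torch.randn(2, 2)
print(data.shape, cond.shape)
```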

Thank you. I actually based my conditional VAE on an MNIST implementation, and I know MNIST is different from my dataset, but now I don't know how to fix this. How can I make the generator return just two values?
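The loop needs each item from the generator to be a (data, cond) pair, which is the structure an MNIST DataLoader provides as (images, labels). For the 8-Gaussian toy data, one natural condition is the index of the Gaussian each sample was drawn from. The sketch below assumes that choice (so cond_dim = 8); NUM_CENTERS is a hypothetical name, and torch.nn.functional.one_hot stands in for the one_hot helper used in the post, whose definition is not shown.

```python
import math
from itertools import islice

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CENTERS = 8  # assumption: one condition class per Gaussian, so cond_dim = 8


def gen_batch(batch_size):
    # Which of the 8 Gaussians each sample comes from; used as the condition.
    labels = torch.randint(0, NUM_CENTERS, (batch_size,), device=device)
    theta = (math.pi / 4) * labels.float()
    centers = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)
    noise = torch.randn_like(centers) * 0.1
    # Return a (data, cond) pair, the same structure an MNIST DataLoader yields.
    return centers + noise, labels


def data_gen(batch_size):
    # Infinite stream of (data, labels) batches.
    while True:
        yield gen_batch(batch_size)


train_loader = data_gen(64)

# islice stops after exactly 100 batches, replacing the manual break.
for batch_idx, (data, cond) in enumerate(islice(train_loader, 100)):
    # F.one_hot stands in for the post's own one_hot(cond, cond_dim) helper.
    cond = F.one_hot(cond, num_classes=NUM_CENTERS).float()
    # The forward pass, loss and optimizer step would follow here.
```

With this generator, the original line `for batch_idx, (data,cond) in enumerate(train_loader):` unpacks correctly. `itertools.islice` is just one way to cap an epoch; note that the original `if batch_idx > 100: break` actually processes 101 batches (indices 0 to 100).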