I try to implement conditional VAE with toy data set but I got this error. Here is my code
def gen_batch(BATCH_SIZE):
theta = (np.pi/4) * torch.randint(0, 8, (BATCH_SIZE,)).float().to(device)
centers = torch.stack((torch.cos(theta), torch.sin(theta)), dim = -1)
noise = torch.randn_like(centers) * 0.1
return centers + noise
def data_gen(BATCH_SIZE):
#8 gaussians
while 1:
yield gen_batch(BATCH_SIZE)
test_loader = train_loader = data_gen(args.batch_size)
def train(epoch):
CVAE.train()
train_loss = 0
for batch_idx, (data,cond) in enumerate(train_loader):
if batch_idx > 100:
break #100 batches per epoch
# data = data.to(device)
data, cond = data.cuda(), one_hot(cond, cond_dim).cuda()
optimizer.zero_grad()
recon_batch, mu, logvar = CVAE(data, cond)
print("mu", mu)
print("logvar", logvar)
loss = loss_function(recon_batch, data, mu, logvar)
recon_batch = recon_batch.detach().numpy()
ValueError Traceback (most recent call last)
<ipython-input-93-9983f3e978b5> in <module>
1 for epoch in range(1, 51):
----> 2 train(epoch)
<ipython-input-92-aa690337c5d3> in train(epoch)
3 train_loss = 0
4
----> 5 for batch_idx, (data,cond) in enumerate(train_loader):
6 if batch_idx > 100:
7 break #100 batches per epoch
ValueError: too many values to unpack (expected 2)
I would be really grateful if anyone could help me