Above is the result of my autoencoder after 50 epochs of training… ridiculous, I know. I have trained autoencoders before, so I am a little flummoxed by my inability to find the issue here. It is a bit of a weird model, but not that weird (though my code is a mess because I am experimenting with multiple ideas). I think the loss isn't backpropagating properly, since the worst-case scenario with a working loss is usually a gray image, not this multicolored noise.
Here are some problems I am considering:
- I am training on TPUs, so perhaps I am missing some PyTorch/XLA detail.
- I have two outputs and I hold the two output tensors in a list, which is then held in a dictionary before being used in the loss function. I have done this before, so I don't think that should be an issue.
- I am using nn.Upsample, which I have not used before, but I think the loss should be able to backpropagate through it without a problem.
- Gradient collapse??? I am using LeakyReLU and the loss is never NaN (see the gradient-check sketch after this list).
- I have tested with both BCE and MSE loss; both result in the multicolored mess.
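To test the backpropagation theory directly, this is a minimal check I can drop in right after lossy.backward() (names match the training loop below; nothing here is TPU-specific):

# Print a gradient norm for every parameter; a parameter that never
# receives a gradient shows up as "NO GRAD".
for name, p in model.named_parameters():
    if p.grad is None:
        print(f"NO GRAD: {name}")
    else:
        print(f"{name}: grad norm = {p.grad.norm().item():.3e}")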
Here is my training loop… pretty standard, except that I use a lot of dictionaries.
# (outputDict and lossDict are created earlier in my code)
for image_batch, labels in data_loader:
    # for step, (image_batch, labels) in enumerate(para_loader):
    image_batch = image_batch.to(device)
    labels = labels.to(device).float()
    # VAE reconstruction
    outputDict = model(image_batch, outputDict)
    # reconstruction error
    lossDict = VACTS_loss(image_batch, labels, run, outputDict, lossDict)
    # loss propagation through each VCap
    lossy = lossDict["loss"]
    # backpropagation
    optimizer.zero_grad()
    lossy.backward(retain_graph=True)
    # one step of the optimizer (using the gradients from backpropagation)
    # optimizer.step()
    xm.optimizer_step(optimizer, barrier=True)
    if verbose == 2: print("|", end="")
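Since the commented-out para_loader line is from an earlier attempt, here is a sketch of the two PyTorch/XLA data-loading patterns as I understand them from the torch_xla docs; mixing them up is one thing I want to rule out. compute_loss is a hypothetical stand-in for my model + VACTS_loss calls:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()

# MpDeviceLoader moves each batch to the TPU and inserts the XLA step
# barrier itself, so optimizer_step is called WITHOUT barrier=True.
mp_loader = pl.MpDeviceLoader(data_loader, device)
for image_batch, labels in mp_loader:
    optimizer.zero_grad()
    loss = compute_loss(model, image_batch, labels)  # hypothetical helper
    loss.backward()
    xm.optimizer_step(optimizer)

# With a plain DataLoader (what I do above), barrier=True is required so
# the lazily built XLA graph actually executes every step:
#     xm.optimizer_step(optimizer, barrier=True)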
My loss function is (simplified for readability):
if run.daloss == "BCE":
    glb_recon_loss = nn.BCEWithLogitsLoss(reduction='mean')
if run.daloss == "MSE":
    glb_recon_loss = nn.MSELoss(reduction='mean')  # add sigmoid to the output for this loss
glb_recon_loss_total = 0
for varActor in range(run.VC_num):
    # global reconstruction loss
    if haveGlbReconLoss:
        glb_recon_loss_part = glb_recon_loss(recon_x_list[varActor], image_batch)
        glb_recon_loss_total += glb_recon_loss_part
loss = lambda4 * glb_recon_loss_total  # lambda4 = 1
lossDict["loss"] = loss
return lossDict
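One more thing I want to rule out: the activation/loss pairing. As I understand it, BCEWithLogitsLoss expects raw logits (it applies the sigmoid internally), while MSELoss against images in [0, 1] needs an explicit sigmoid on the decoder output. A minimal sketch of what I mean, assuming image_batch is normalized to [0, 1] and using the names from my loss above:

import torch
import torch.nn as nn

logits = recon_x_list[varActor]  # raw decoder output, no final activation
if run.daloss == "BCE":
    # BCEWithLogitsLoss applies the sigmoid internally -> pass raw logits
    part = nn.BCEWithLogitsLoss(reduction='mean')(logits, image_batch)
else:  # "MSE"
    # MSELoss has no built-in squashing -> squash logits to [0, 1] first
    part = nn.MSELoss(reduction='mean')(torch.sigmoid(logits), image_batch)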
Any ideas what this issue may be?? Please and thank you!