I’m collecting two different measures of loss and propagating both of them to a single network. During the loss.backward() to the generator, I get this error about an invalid gradient at index 0.
Here is the code block:
for epoch in range(num_epochs):
for batch_idx, (real, labels) in enumerate(loader):
#get a fixed input batch to display gen output
if batch_idx == 0:
if epoch == 0:
fixed_input = real.view(-1,784).to(device)
adv_ex = real.clone().reshape(-1,784).to(device) # [32, 784] advex copy of first batch flattened
real = real.view(-1, 784).to(device) # [32, 784] # real batch flattened
labels = labels.to(device) # size() [32] 32 labels in batch
#purturb each image in adv_ex
tmp_adv_ex = []
for idx, item in enumerate(adv_ex):
purturbation = gen(adv_ex[idx])
tmp_adv_ex.append(adv_ex[idx] + purturbation)
adv_ex = torch.cat(tmp_adv_ex, dim=0)
### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
output = disc(adv_ex).view(-1)
lossG = torch.mean(torch.log(1. - output)) #get loss for gen's desired desc pred
adv_ex = adv_ex.reshape(-1,1,28,28)
f_pred = target(adv_ex)
f_loss = CE_loss(f_pred, labels) #add loss for gens desired f pred
loss_G_Final = f_loss+lossG
opt_gen.zero_grad()
loss_G_Final.backward() #THIS IS THE ERROR SOURCE
opt_gen.step()
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
adv_ex = adv_ex.reshape(32, 784)
disc_real = disc(real).view(-1)
disc_fake = disc(adv_ex).view(-1)
lossD = -torch.mean(torch.log(disc(real)) + torch.log(1. - disc(adv_ex)))
opt_disc.zero_grad()
lossD.backward(retain_graph=True)
opt_disc.step()
and this is the exact error code:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_31570/4032279268.py in <module>
30
31 opt_gen.zero_grad()
---> 32 loss_G_Final.backward()
33 opt_gen.step()
34
~/.conda/envs/mypytorch19/lib/python3.9/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
253 create_graph=create_graph,
254 inputs=inputs)
--> 255 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
256
257 def register_hook(self, hook):
~/.conda/envs/mypytorch19/lib/python3.9/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
145 retain_graph = create_graph
146
--> 147 Variable._execution_engine.run_backward(
148 tensors, grad_tensors_, retain_graph, create_graph, inputs,
149 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [1, 784] but expected shape compatible with [1, 25088]
How can I interpret and solve this error?
Thanks!