Hi all,
I’m relatively new to PyTorch and have run into a problem with back-propagation that I don’t know how to solve.
First of all, here is my code for calculating the loss and then performing back-propagation:
```python
from torch import ones_like, zeros_like
from torch.nn import BCEWithLogitsLoss

def discriminator_loss(self, real_samples, fake_samples):
    criterion = BCEWithLogitsLoss()
    real_logits = self.discriminator(real_samples)
    fake_logits = self.discriminator(fake_samples)
    # Label real samples 1 and fake samples 0; *_like matches the logits' shape and device.
    real_loss = criterion(real_logits, ones_like(real_logits))
    fake_loss = criterion(fake_logits, zeros_like(fake_logits))
    return (real_loss + fake_loss) / 2

...

def optimise_discriminator(self, input, targets):
    # Load input and targets into GPU memory (.cuda() returns a copy, so reassign it).
    input = input.cuda()
    targets = targets.cuda()
    # Pass input through the generator (call the module rather than .forward()).
    outputs = self.generator(input)
    # Get the loss; call the method through self and don't shadow its name.
    loss = self.discriminator_loss(targets, outputs)
    # Optimise the discriminator.
    self.optimiser_discriminator.zero_grad()
    loss.backward()
    self.optimiser_discriminator.step()
    # Return the loss as a Python float.
    return loss.item()
```
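To check that the loss computation on its own back-propagates, here is a minimal self-contained sketch. It assumes a hypothetical toy discriminator built only from standard modules and random 4×4 inputs in place of the real models and data, so every operation in the graph has a built-in backward.

```python
import torch
from torch import nn

# Hypothetical stand-in discriminator made only of built-in modules.
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
criterion = nn.BCEWithLogitsLoss()

# Hypothetical batch of one 1x4x4 "real" and one "fake" sample.
real_samples = torch.randn(1, 1, 4, 4)
fake_samples = torch.randn(1, 1, 4, 4)

real_logits = discriminator(real_samples)  # shape (1, 1)
fake_logits = discriminator(fake_samples)
loss = (criterion(real_logits, torch.ones_like(real_logits))
        + criterion(fake_logits, torch.zeros_like(fake_logits))) / 2

loss.backward()  # no custom autograd.Function in the graph, so this succeeds
grad_ok = all(p.grad is not None for p in discriminator.parameters())
print(grad_ok)
```

If this runs cleanly, the BCE loss itself is not the culprit and the problem is more likely inside one of the models.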
The problem occurs when I call backward() on the value returned by discriminator_loss. I get the following error:
```
Traceback (most recent call last):
  File "train.py", line 84, in <module>
    main()
  File "train.py", line 74, in main
    generator_loss = model.optimise_discriminator(input, targets)
  File "/nobackup/sccmho/projects/DBT2FFDM/models/dbt2ffdm_model.py", line 101, in optimise_discriminator
    discriminator_loss.backward()
  File "/nobackup/sccmho/.conda/envs/pix2pix/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/nobackup/sccmho/.conda/envs/pix2pix/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/nobackup/sccmho/.conda/envs/pix2pix/lib/python3.7/site-packages/torch/autograd/function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "/nobackup/sccmho/.conda/envs/pix2pix/lib/python3.7/site-packages/torch/autograd/function.py", line 181, in backward
    raise NotImplementedError
```
The same error occurs when discriminator_loss returns only real_loss or only fake_loss. I’m at a loss as to what is causing this, since I’m not using a custom loss function. Any guidance would be appreciated.
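The last frames of the traceback sit in torch/autograd/function.py (`self._forward_cls.backward`), which appears to be the base `torch.autograd.Function.backward` rather than anything in a loss. One way to reach that line, sketched here under the assumption that some layer in the generator or discriminator wraps a custom `torch.autograd.Function`, is a Function that defines `forward` but never overrides `backward`. `Doubler` below is purely hypothetical.

```python
import torch

class Doubler(torch.autograd.Function):
    """Hypothetical custom Function: defines forward but not backward."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

x = torch.ones(3, requires_grad=True)
y = Doubler.apply(x).sum()  # the forward pass works fine

try:
    y.backward()  # falls through to the base Function.backward
    raised = False
except NotImplementedError:
    raised = True
print(raised)
```

Searching the model code for subclasses of `torch.autograd.Function` (or third-party layers that use them) would be a reasonable next step under this assumption.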