Replace loss value using torch.no_grad?

Hi, I have a Generator network that learns matching relationships between two sets (e.g., producing index pairs based on matching scores), and a Discriminator network that judges whether the matching quality is good.

The problem is that the Discriminator takes the matched data as input, and getting the index pairs is a non-differentiable operation that breaks the computation graph, so I cannot use the Discriminator loss to update the Generator.
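A minimal sketch of why the graph breaks, assuming toy shapes and simple `nn.Linear` stand-ins for both networks (the sizes, the models, and the `.mean()` placeholder loss are illustrative assumptions, not the original code). The indices returned by `max` are integer tensors without a `grad_fn`, so a Discriminator loss computed from `data[indices]` has no path back to the Generator's parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_items, n_features = 5, 8  # hypothetical sizes for illustration

# Stand-ins for the real networks (assumption: simple linear layers)
generator = nn.Linear(n_features, n_items)  # yields an (n_items, n_items) score matrix
discriminator = nn.Linear(n_features, 1)

data = torch.randn(n_items, n_features)

scores = generator(data)            # differentiable w.r.t. generator params
indices = scores.max(1).indices     # integer indices: no grad_fn, no gradient

d_output = discriminator(data[indices])
d_loss = d_output.mean()            # placeholder for criterion(d_output, d_label)
d_loss.backward()

print(indices.requires_grad)                  # False
print(generator.weight.grad)                  # None: generator is not in d_loss's graph
print(discriminator.weight.grad is not None)  # True
```

Only the Discriminator receives gradients here; the Generator's parameters are simply not part of the graph that `d_loss` was built from.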

One workaround I am considering is to calculate the mean of the matching scores, which is still differentiable, replace this value with the Discriminator loss value under `with torch.no_grad()`, and then call loss.backward() to update the Generator.

So this is like defining a fake loss function based on the matching scores and then replacing its value with the actual loss value I want to use (calculated by the Discriminator). My understanding is that backward() will then use the computation graph of the matching scores to update the Generator.
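Since this idea relies on the gradient of `scores.mean()`, it may help to check what that gradient actually is. A quick sketch with a hypothetical random score matrix (the shape is arbitrary and stands in for `generator(data, noise)`) suggests that every score receives the same constant gradient, `1 / scores.numel()`, independent of the loss value:

```python
import torch

torch.manual_seed(0)

# Hypothetical matching-score matrix standing in for generator(data, noise)
scores = torch.randn(4, 6, requires_grad=True)

fake_loss = scores.mean()
(grad,) = torch.autograd.grad(fake_loss, scores)

# d(mean)/d(scores) is a constant tensor filled with 1 / scores.numel()
print(grad.unique())  # a single value: 1 / 24
```

In score space, the Generator would therefore only receive a uniform push on all scores, and that push carries no information from the Discriminator, whatever value the loss tensor holds.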

Do you think this idea will work? Are there any misunderstandings or issues I am not aware of? Any suggestions or comments would be greatly appreciated!

Hello, can someone provide some comments here? Thanks.

I don’t think your idea is valid, as manipulating the loss value in a no_grad() context should either raise an error or not change the gradient calculation at all.

@ptrblck Thanks so much for the reply! Do you think there is a way to modify a loss value without breaking the gradient graph (such as detaching it and changing the value)? I am currently doing something like this:

# update G
generator.zero_grad()

# get matching scores (differentiable)
scores = generator(data, noise)

# get indices from scores (non-differentiable)
indices = scores.max(1).indices

# permute the data and calculate loss via discriminator
d_output = discriminator(data[indices])
d_loss = criterion(d_output, d_label)

# make fake loss using scores.mean() and change its value via detaching
fake_g_loss = scores.mean()
fake_g_loss_detached = fake_g_loss.detach()
fake_g_loss_detached *= 0.
fake_g_loss_detached += d_loss.item()

# Calculate gradients for G
fake_g_loss_detached.backward()
  
# Update G
optimizerG.step()

I checked fake_g_loss_detached, and it seems the value was updated (replaced with the d_loss value) and the computation graph was not affected. I also trained the model and it is running. But I am not sure if this really works. Would such operations make sense to you?

I’m currently unsure why your code works at all, since fake_g_loss_detached is indeed detached (as the name indicates) and the backward() call should fail as seen here:

import torch
import torch.nn as nn

generator = nn.Linear(10, 10)
data = torch.randn(1, 10)

discriminator = nn.Linear(10, 10)
criterion = nn.MSELoss()

# get matching scores (differentiable)
scores = generator(data)

# get indices from scores (non-differentiable)
indices = scores.max(1).indices

# permute the data and calculate loss via discriminator
d_output = discriminator(data)
d_loss = criterion(d_output, torch.rand_like(d_output))

# make fake loss using scores.mean() and change its value via detaching
fake_g_loss = scores.mean()
fake_g_loss_detached = fake_g_loss.detach()
fake_g_loss_detached *= 0.
fake_g_loss_detached += d_loss.item()

# Calculate gradients for G
fake_g_loss_detached.backward() # this should fail !!!
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

@ptrblck Oh sorry, I made a mistake: it was actually fake_g_loss.backward(). I tried to simplify the code, and the original has more content. Would fake_g_loss.backward() update the network using the new value, or will it still use the old one?

Ah OK, this makes sense.
fake_g_loss.backward() will use its own loss value to calculate the gradients. You could seed the code to make it reproducible and check the gradients after the backward call e.g. via:

print(generator.weight.grad.abs().sum())

Then you could play around with scaling fake_g_loss_detached and would see that no value changes the computed gradients.
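The experiment suggested above can be sketched as follows, reusing the toy `nn.Linear` setup from the repro (the seed, the sizes, and the value `3.7` standing in for `d_loss.item()` are assumptions). It compares the generator gradients from a plain `fake_g_loss.backward()` with those obtained after overwriting the value through a detached alias, and after the commonly used `fake + (target - fake).detach()` trick, which replaces the forward value while keeping the backward pass of `fake`. Note that `detach()` shares storage with the original tensor, which would explain why the in-place edit to `fake_g_loss_detached` appeared to update the loss value:

```python
import torch
import torch.nn as nn


def generator_grad(mode):
    """Return (loss value, generator.weight.grad) after one backward pass.

    mode: "plain"   -> fake_g_loss.backward()
          "detach"  -> overwrite the value via a detached alias, then backward
          "replace" -> fake + (target - fake).detach(), then backward
    """
    torch.manual_seed(0)           # identical weights and data for every mode
    generator = nn.Linear(10, 10)
    data = torch.randn(1, 10)
    d_loss_value = 3.7             # hypothetical stand-in for d_loss.item()

    scores = generator(data)
    fake_g_loss = scores.mean()

    if mode == "detach":
        fake_g_loss_detached = fake_g_loss.detach()  # shares storage with fake_g_loss
        fake_g_loss_detached *= 0.0
        fake_g_loss_detached += d_loss_value
        loss = fake_g_loss                           # its value is now d_loss_value too
    elif mode == "replace":
        target = torch.tensor(d_loss_value)
        loss = fake_g_loss + (target - fake_g_loss).detach()  # forward value == target
    else:
        loss = fake_g_loss

    loss.backward()
    print(mode, round(loss.item(), 4), generator.weight.grad.abs().sum().item())
    return loss.item(), generator.weight.grad.clone()


plain_val, plain_grad = generator_grad("plain")
detach_val, detach_grad = generator_grad("detach")
replace_val, replace_grad = generator_grad("replace")
```

The printed loss value changes between modes, but the gradient sum stays the same, which matches the point that the gradient of `scores.mean()` never looks at the loss value.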

I’m not sure if this would fit your use case, but in case you want to use a different code path in the backward pass, you could check, e.g., this approach described by @tom.
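I can't confirm exactly what the linked approach contains, but one widely used way to give a hard, non-differentiable choice a gradient is the straight-through trick: use the hard (argmax) selection in the forward pass and the gradient of a soft (softmax) relaxation in the backward pass. A minimal sketch for the matching setup, assuming the scores form an (n, n) matrix where row i scores the candidates for item i, and that permuting the data can be written as a product with a one-hot assignment matrix (both assumptions, as are the toy linear models and the placeholder loss):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

n_items, n_features = 5, 8                  # hypothetical sizes
generator = nn.Linear(n_features, n_items)  # stand-in: row i scores candidates for item i
discriminator = nn.Linear(n_features, 1)
data = torch.randn(n_items, n_features)

scores = generator(data)                    # (n_items, n_items)

# Soft assignment: differentiable, each row sums to 1
soft = F.softmax(scores, dim=1)

# Hard assignment: one-hot of the argmax, same choice as scores.max(1).indices
indices = scores.max(1).indices
hard = F.one_hot(indices, num_classes=n_items).to(soft.dtype)

# Straight-through: forward value equals `hard`, gradient flows through `soft`
assign = hard + (soft - soft.detach())

# Same values as data[indices] in the forward pass, but differentiable w.r.t. scores
matched = assign @ data

d_output = discriminator(matched)
g_loss = d_output.mean()                    # placeholder for the real criterion
g_loss.backward()

print(torch.allclose(matched, data[indices]))  # True
print(generator.weight.grad is not None)       # True
```

The backward pass uses the softmax gradient as a biased surrogate for the argmax, so it is a heuristic rather than an exact gradient. PyTorch also provides `F.gumbel_softmax(scores, hard=True, dim=1)`, which applies the same straight-through construction with sampled Gumbel noise. Like the original `scores.max(1)`, this sketch does not enforce a one-to-one matching; several rows may pick the same index.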

Ah, thanks so much for the verification! Yes, I tried it a bit, and it looks like it only uses its own value to calculate the gradients. Also, thanks for sharing the approach with me. It looks promising and I will give it a try.