How can I update only part of the weight vector, and only occasionally?

I have a combined network with a generator and a discriminator.
Joint training of this network collapses because the discriminator overfits.

So I am trying to update the discriminator only occasionally,
i.e., the generator is updated at every time step, but the discriminator is updated only once every hundred steps.

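The pseudo-code of the model is:

import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.a, self.b = a, b  # constant weights of the two loss terms
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

    def forward(self, desired, predicted):
        # PyTorch losses take (input, target), i.e. (predicted, desired)
        loss_discriminator = self.ce(predicted, desired)
        loss_generator = self.mse(predicted, desired)
        return self.a * loss_discriminator + self.b * loss_generator

criterion = CustomLoss(a, b)
optimizer.zero_grad()
predicted = model(input)
loss = criterion(desired, predicted)
loss.backward()
optimizer.step()  # one optimizer: every step updates generator and discriminator weights

(where a and b are constants)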
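If you want to keep the single optimizer and combined loss from the pseudo-code above, one possible approach (an alternative to the two-optimizer answer, not what the thread settled on) is to freeze the discriminator on the steps where it should not change. The sketch below assumes a hypothetical combined model with `generator` and `discriminator` submodules and made-up toy data. With `requires_grad_(False)` and `zero_grad(set_to_none=True)`, the frozen parameters keep `grad=None`, and PyTorch's built-in optimizers skip parameters whose gradient is `None`, so even Adam's momentum does not move them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Combined(nn.Module):
    """Hypothetical combined network: generator followed by discriminator."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(4, 4)
        self.discriminator = nn.Linear(4, 3)  # 3-class logits

    def forward(self, x):
        return self.discriminator(self.generator(x))


model = Combined()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # single optimizer
ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
a, b = 1.0, 0.5   # assumed constant loss weights
DISC_EVERY = 100  # update the discriminator once every 100 steps

for i in range(250):
    # Placeholder data; replace with your real inputs and targets.
    x = torch.randn(8, 4)
    labels = torch.randint(0, 3, (8,))
    target = torch.randn(8, 4)

    # Freeze the discriminator on all but every DISC_EVERY-th step.
    update_disc = i % DISC_EVERY == 0
    model.discriminator.requires_grad_(update_disc)

    # set_to_none=True leaves frozen params with grad=None, so step() skips them.
    optimizer.zero_grad(set_to_none=True)
    gen_out = model.generator(x)
    logits = model.discriminator(gen_out)  # gradients still flow back to the generator
    loss = a * ce(logits, labels) + b * mse(gen_out, target)
    loss.backward()
    optimizer.step()
```

After 250 iterations the generator has been stepped 250 times and the discriminator only 3 times (at steps 0, 100 and 200). Note that with a momentum-based optimizer, zeroing gradients to `0` instead of `None` would not be enough, because the stored moments would still move the weights.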


Instead of having one loss function and one optimizer for both your discriminator and generator, you could have a loss and optimizer for each of them. This means your training loop would look something like this:

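for i in range(NUM_ITERATIONS):
    # Generator: updated every iteration
    optim_gen.zero_grad()
    loss_gen = ...
    loss_gen.backward()
    optim_gen.step()

    # Discriminator: updated only once every 100 iterations
    if i % 100 == 0:
        optim_disc.zero_grad()  # drop gradients accumulated by earlier backward() calls
        loss_disc = ...
        loss_disc.backward()
        optim_disc.step()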
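For a fully runnable version of that loop, here is a minimal sketch with hypothetical toy `nn.Linear` networks, random placeholder data, and a standard BCE-based GAN objective standing in for your own losses. The key point is that each optimizer is constructed only over its own sub-network's parameters, so `optim_gen.step()` never touches the discriminator weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks standing in for the real generator/discriminator.
generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

# One optimizer per sub-network, each owning a disjoint set of parameters.
optim_gen = torch.optim.Adam(generator.parameters(), lr=1e-3)
optim_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

NUM_ITERATIONS = 300
DISC_EVERY = 100   # discriminator update period, as in the question
disc_updates = []  # iterations at which the discriminator was stepped

for i in range(NUM_ITERATIONS):
    real = torch.randn(8, 4) + 2.0  # placeholder "real" samples
    noise = torch.randn(8, 4)

    # Generator step (every iteration): try to make D classify fakes as real.
    optim_gen.zero_grad()
    fake = generator(noise)
    loss_gen = bce(discriminator(fake), torch.ones(8, 1))
    loss_gen.backward()  # also fills discriminator grads; they are cleared below
    optim_gen.step()

    # Discriminator step (once every DISC_EVERY iterations).
    if i % DISC_EVERY == 0:
        optim_disc.zero_grad()  # clear grads left over from loss_gen.backward()
        d_real = discriminator(real)
        d_fake = discriminator(fake.detach())  # detach: don't backprop into G
        loss_disc = bce(d_real, torch.ones(8, 1)) + bce(d_fake, torch.zeros(8, 1))
        loss_disc.backward()
        optim_disc.step()
        disc_updates.append(i)

print(disc_updates)  # [0, 100, 200]
```

Detaching `fake` in the discriminator step keeps the discriminator loss from producing generator gradients, which matters if you ever reorder the two steps.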

Hope this helps!


Oh, I forgot about using multiple optimizers and loss objects. Now it is solved. Thanks!
