CycleGAN - Turn off training

Morning all,

I was looking at a Keras CycleGAN, and they use the following flags to deactivate training in the combined model of generator 1 -> discriminator 1 -> generator 2:

def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
	# ensure the model we're updating is trainable
	g_model_1.trainable = True
	# mark discriminator as not trainable
	d_model.trainable = False
	# mark other generator model as not trainable
	g_model_2.trainable = False
	# discriminator element
	# ...

Would the equivalent be:

g_model_2=g_model_2.train(mode=false)

where g_model_2 is defined by a class? If so, why do I get the following error message?

TypeError: train() missing 1 required positional argument: 'self'

Thanks in advance.

Chaslie

Most likely model.train() does not correspond to the trainable argument in Keras.
model.train() and model.eval() switch the behavior of some modules, e.g. eval() will disable dropout and use the running stats in batchnorm layers.
(The TypeError itself suggests that g_model_2 is the model class rather than an instance, so train() is called without a bound self.)
If you want to freeze parameters, you can set param.requires_grad = False.
However, note that optimizers with running estimates or momentum can still update these parameters even if the gradient is zero; a small sketch of both points follows below.
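
To make both points concrete, here is a minimal sketch (assuming toy nn.Dropout and nn.Linear modules and Adam as the optimizer): eval() changes module behavior, and a zeroed (not None) gradient still lets Adam move a parameter through its running estimates:

import torch
import torch.nn as nn

# 1) train()/eval() switch module behavior, e.g. dropout
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)
drop.train()
print(drop(x))  # roughly half the entries zeroed, the rest scaled by 2
drop.eval()
print(drop(x))  # identity: dropout is disabled in eval mode

# 2) zeroed gradients don't stop optimizers with running estimates
lin = nn.Linear(1, 1)
optimizer = torch.optim.Adam(lin.parameters(), lr=0.1)

# one real step populates Adam's running estimates
lin(torch.randn(1, 1)).mean().backward()
optimizer.step()

# manually zero the gradients (zero tensors, not None)
for param in lin.parameters():
    param.grad.zero_()

before = lin.weight.detach().clone()
optimizer.step()  # running estimates are non-zero -> weight still moves
print(torch.equal(before, lin.weight))  # False, despite a zero gradient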

Hi Ptrblck,

would

with torch.no_grad():

for that particular training block work?

with torch.no_grad() would work if you don't need any gradients for the parameters used before this block, since the no_grad() guard does not create the computation graph for this section.

Here is a small example:

# setup
import torch
import torch.nn as nn

lin1 = nn.Linear(1, 1)
lin2 = nn.Linear(1, 1)
lin3 = nn.Linear(1, 1)

x = torch.randn(1, 1)

# case 1: lin1 and lin2 both run inside no_grad, so neither gets gradients
with torch.no_grad():
    out = lin1(x)
    out = lin2(out)
out = lin3(out)
out.mean().backward()

print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad)
> None None tensor([[-0.8122]])
lin1.zero_grad(), lin2.zero_grad(), lin3.zero_grad()  # reset grads between cases

# case 2: lin1 runs before no_grad, but its output is only consumed inside it
out = lin1(x)
with torch.no_grad():
    out = lin2(out)
out = lin3(out)
out.mean().backward()

print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad) # !!! lin1 doesn't have grads !!!
> None None tensor([[-0.8122]])
lin1.zero_grad(), lin2.zero_grad(), lin3.zero_grad()

# case 3: freeze lin2 via requires_grad=False; gradients still flow through it to lin1
for param in lin2.parameters():
    param.requires_grad = False

out = lin1(x)
out = lin2(out)
out = lin3(out)
out.mean().backward()
print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad)
> tensor([[-0.0768]]) None tensor([[-0.8122]])
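
Mapping case 3 back onto the original question, here is a minimal sketch with tiny nn.Linear stand-ins for the real generator/discriminator networks (hypothetical names matching the Keras snippet), where only g_model_1 accumulates gradients in the composite update:

import torch
import torch.nn as nn

# tiny stand-ins for the real CycleGAN submodels
g_model_1 = nn.Linear(4, 4)
d_model = nn.Linear(4, 1)
g_model_2 = nn.Linear(4, 4)

# equivalent of the Keras trainable flags in the composite model
for param in d_model.parameters():
    param.requires_grad = False
for param in g_model_2.parameters():
    param.requires_grad = False

x = torch.randn(2, 4)
fake = g_model_1(x)
adv = d_model(fake).mean()    # adversarial term through the frozen discriminator
cyc = g_model_2(fake).mean()  # cycle term through the frozen second generator
(adv + cyc).backward()

print(g_model_1.weight.grad is not None)  # True: only generator 1 is updated
print(d_model.weight.grad, g_model_2.weight.grad)  # None None

Gradients still flow through the frozen modules back to g_model_1, which is exactly what the combined generator update needs.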