I was looking at a Keras CycleGAN implementation, and they use the following parameters to deactivate training in the combined model generator1 -> discriminator1 -> generator2:

def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element

Would the equivalent be:

g_model_2=g_model_2.train(mode=false)

where g_model_2 is defined by a class? If so, why do I get the error message?

Most likely model.train() does not correspond to the trainable argument in Keras. model.train() and model.eval() switch the behavior of some modules, e.g. dropout is disabled during eval and batchnorm layers use their running stats instead of batch statistics. (Also, the lowercase false in your snippet will raise a NameError, since Python's boolean is spelled False.)
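To make the distinction concrete, here is a minimal sketch (the layers are just placeholders): train()/eval() toggle layer behavior, but leave requires_grad untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

drop.train()          # training mode: dropout is active
out_train = drop(x)   # each element is either zeroed or scaled by 1/(1-p) = 2

drop.eval()           # eval mode: dropout is a no-op
out_eval = drop(x)    # identical to the input

# train()/eval() never touch requires_grad:
lin = nn.Linear(1, 1)
lin.eval()
print(lin.weight.requires_grad)  # -> True, still trainable
```

So calling model.eval() does not freeze anything; it only changes the forward behavior of such modules.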
If you want to freeze parameters, you can set param.requires_grad=False.
However, note that optimizers with running estimates or momentum can still update these parameters even if the gradient is zero.
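As a small sketch of that caveat (using Adam and a toy Linear layer as a stand-in): once the optimizer has built up running estimates, a zero gradient no longer means a zero update, which is why the safest freeze is requires_grad=False combined with keeping frozen parameters out of the optimizer entirely:

```python
import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
opt = torch.optim.Adam(lin.parameters(), lr=0.1)

# one real step so Adam builds up its running estimates
lin(torch.ones(1, 1)).mean().backward()
opt.step()

# now "freeze" by zeroing (not removing) the gradients
lin.weight.grad.zero_()
lin.bias.grad.zero_()
before = lin.weight.detach().clone()
opt.step()

# the weight still moved, driven purely by Adam's running estimates
print(torch.equal(before, lin.weight))  # -> False
```

With requires_grad=False the .grad attribute stays None, and parameters with a None gradient are skipped by the optimizer, so that failure mode is avoided.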

Wrapping the forward pass in torch.no_grad() would also work, as long as you don't need gradients for any parameters used before this block, since the no_grad() guard does not create a computation graph for that section and therefore also cuts off the graph for everything that came before it.

Here is a small example:

import torch
import torch.nn as nn

# setup
lin1 = nn.Linear(1, 1)
lin2 = nn.Linear(1, 1)
lin3 = nn.Linear(1, 1)
x = torch.randn(1, 1)

# case 1: lin1 and lin2 run inside no_grad, so neither gets gradients
with torch.no_grad():
    out = lin1(x)
    out = lin2(out)
out = lin3(out)
out.mean().backward()
print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad)
> None None tensor([[-0.8122]])
lin1.zero_grad(), lin2.zero_grad(), lin3.zero_grad()

# case 2: only lin2 runs inside no_grad, but the graph is cut there,
# so the upstream lin1 gets no gradients either
out = lin1(x)
with torch.no_grad():
    out = lin2(out)
out = lin3(out)
out.mean().backward()
print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad)  # !!! lin1 doesn't have grads !!!
> None None tensor([[-0.8122]])
lin1.zero_grad(), lin2.zero_grad(), lin3.zero_grad()

# case 3: freezing lin2 via requires_grad=False still lets
# gradients flow through it back to lin1
for param in lin2.parameters():
    param.requires_grad = False
out = lin1(x)
out = lin2(out)
out = lin3(out)
out.mean().backward()
print(lin1.weight.grad, lin2.weight.grad, lin3.weight.grad)
> tensor([[-0.0768]]) None tensor([[-0.8122]])
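Tying this back to the original CycleGAN question, here is a rough PyTorch analogue of the Keras trainable flags. The models are tiny Linear stand-ins for the real generators/discriminator, so treat this as a sketch of the pattern, not the actual architecture:

```python
import torch
import torch.nn as nn

# hypothetical stand-ins for the CycleGAN models
g_model_1 = nn.Linear(4, 4)   # generator being updated
d_model   = nn.Linear(4, 1)   # discriminator, frozen
g_model_2 = nn.Linear(4, 4)   # other generator, frozen

def set_trainable(model, flag):
    # PyTorch analogue of Keras' model.trainable attribute
    for p in model.parameters():
        p.requires_grad = flag

set_trainable(g_model_1, True)
set_trainable(d_model, False)
set_trainable(g_model_2, False)

# pass only the trainable parameters to the optimizer, so
# momentum cannot touch the frozen models either
opt = torch.optim.Adam(g_model_1.parameters(), lr=2e-4)

x = torch.randn(2, 4)
loss = d_model(g_model_1(x)).mean() + g_model_2(g_model_1(x)).mean()
loss.backward()
opt.step()

print(d_model.weight.grad)    # -> None: the discriminator stays frozen
print(g_model_1.weight.grad is not None)  # -> True: gradients flow through
```

As in case 3 above, gradients still flow *through* the frozen discriminator and second generator back to g_model_1, which is exactly what the composite CycleGAN model needs.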