Hi, all.
I want to freeze a model. I have two questions about this.
1. Is this the right way to freeze?
class Network(nn.Module):
    ...
class NetworkToFreeze(nn.Module):
    ...
for p in network.parameters():
    p.requires_grad = True
for p in network_to_freeze.parameters():
    p.requires_grad = True
...
for epoch in train_process:
    if epoch < 50:
        ...  # train all networks (same code here*)
    if epoch >= 50:
        for p in network_to_freeze.parameters():
            p.requires_grad = False
        ...  # train all networks (same code here*)
######################################################
2. Or is it OK to simply leave that network's loss out of the total that backward() is called on?
if epoch < 50:
    ...  # train all networks (same code here*)
    total_loss = network_loss + network_to_freeze_loss
    total_loss.backward()
if epoch >= 50:
    ...  # train all networks (same code here*)
    total_loss = network_loss
    total_loss.backward()
This raises the following error: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
You can do that… but it’s a little bit strange to split the network into two parts.
You can just run
for p in network.parameters():
    p.requires_grad = True
and use an if statement inside that for loop to pick out the layers you want to freeze:
    if freeze:
        p.requires_grad = False
    else:
        p.requires_grad = True
(Note that the else branch is not strictly necessary if you set them all to True at the beginning.)
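The filter described above might look roughly like the following. This is only a minimal sketch: the toy model, the layer names `backbone` and `head`, and the prefix-based `frozen_prefixes` rule are all hypothetical stand-ins for whichever layers you actually want to freeze.

```python
import torch
from torch import nn

# Hypothetical toy model; the layer names are illustrative only.
model = nn.Sequential()
model.add_module("backbone", nn.Linear(4, 8))
model.add_module("head", nn.Linear(8, 2))

# Assumption: the layers to freeze are selected by parameter-name prefix.
frozen_prefixes = ("backbone",)

for name, p in model.named_parameters():
    freeze = name.startswith(frozen_prefixes)
    p.requires_grad = not freeze

# Hand only the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

# One forward/backward pass: frozen parameters receive no gradient.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
```

After `backward()`, the frozen parameters should have `grad` equal to `None`, which is a quick way to check that the filter matched the layers you intended.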
@JuanFMontesinos
Here is my whole training process.
Can you check this last version?
class Network(nn.Module):
    ...
class NetworkToFreeze(nn.Module):
    ...
for p in network.parameters():
    p.requires_grad = True
for p in network_to_freeze.parameters():
    p.requires_grad = True
...
for epoch in train_process:
    if epoch < 50:
        ...  # train all networks (same code here*)
        total_loss = network_loss + network_to_freeze_loss
        total_loss.backward(retain_graph=True)
    else:
        ...  # train all networks (same code here*)
        for p in network_to_freeze.parameters():
            p.requires_grad = False
        total_loss = network_loss
        total_loss.backward(retain_graph=True)
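For comparison, here is one possible way to arrange such a loop so that the flag is flipped once, before the forward pass of the freezing epoch, and every iteration builds a fresh graph. It is a sketch under assumptions, not a diagnosis of the code above: the two `nn.Linear` modules, the MSE losses, the random data, and the small `freeze_epoch` (standing in for 50) are all hypothetical, and it assumes the two losses come from independent forward passes. In this arrangement `retain_graph=True` is not needed.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks in the thread.
network = nn.Linear(4, 1)
network_to_freeze = nn.Linear(4, 1)

optimizer = torch.optim.SGD(
    list(network.parameters()) + list(network_to_freeze.parameters()), lr=0.1
)

freeze_epoch = 2  # assumption: small stand-in for epoch 50 so the sketch runs quickly
num_epochs = 4
x, target = torch.randn(8, 4), torch.randn(8, 1)  # made-up data
frozen_snapshot = None

for epoch in range(num_epochs):
    if epoch == freeze_epoch:
        # Flip the flag once, before this epoch's forward pass.
        for p in network_to_freeze.parameters():
            p.requires_grad = False
        frozen_snapshot = [p.detach().clone() for p in network_to_freeze.parameters()]

    optimizer.zero_grad()

    # Fresh forward pass each iteration, so backward() always sees a new graph.
    total_loss = nn.functional.mse_loss(network(x), target)
    if epoch < freeze_epoch:
        network_to_freeze_loss = nn.functional.mse_loss(network_to_freeze(x), target)
        total_loss = total_loss + network_to_freeze_loss

    total_loss.backward()  # no retain_graph in this sketch
    optimizer.step()
```

Comparing `frozen_snapshot` with the current parameters of `network_to_freeze` after training is a simple check that the frozen network stopped updating.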