Hi, I want to copy the weights of one network into another network, freeze those weights, and have the second network update only the layers that differ from the first. The second network has some layers that are exactly the same as in the first network and some that are different: I want to copy the first network's weights into the shared layers, freeze them, and then train the second network to learn the weights of the different layers.
My code for freezing the copied layers:
for name, param in model2.named_parameters():
    if param.requires_grad and 'conv1' in name:
        param.requires_grad = False
    if param.requires_grad and 'conv2' in name:
        param.requires_grad = False
    if param.requires_grad and 'conv3' in name:
        param.requires_grad = False
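A minimal sketch of the copy-then-freeze idea, assuming a hypothetical toy `Net` whose `conv1` and `conv2` are shared between the two networks while `head` differs. Instead of deep-copying whole modules, it copies only the shared entries of `model1.state_dict()` via `load_state_dict(..., strict=False)` and then freezes those submodules with `requires_grad_(False)`. `Net`, `SHARED`, and `copy_and_freeze` are illustrative names, not part of PyTorch.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical network: conv1/conv2 are shared, head differs per model.
    def __init__(self, head_out):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=1)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=1)
        self.head = nn.Conv2d(12, head_out, kernel_size=1)

    def forward(self, x):
        return self.head(torch.relu(self.conv2(torch.relu(self.conv1(x)))))

SHARED = ("conv1", "conv2")  # layers that exist with identical shapes in both nets

def copy_and_freeze(src, dst, shared=SHARED):
    # Keep only the tensors whose top-level module name is in `shared`.
    state = {k: v for k, v in src.state_dict().items()
             if k.split(".")[0] in shared}
    # strict=False leaves dst's other (non-shared) parameters untouched;
    # load_state_dict copies values, so dst does not share storage with src.
    dst.load_state_dict(state, strict=False)
    # Freeze the copied layers so the optimizer never updates them.
    for name in shared:
        getattr(dst, name).requires_grad_(False)
    return dst

model1 = Net(head_out=4)
model2 = copy_and_freeze(model1, Net(head_out=8))
print([n for n, p in model2.named_parameters() if p.requires_grad])
# ['head.weight', 'head.bias']
```

Only `head.weight` and `head.bias` remain trainable, so an optimizer built from `filter(lambda p: p.requires_grad, model2.parameters())` updates just the layer that differs.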
Hi @matheo_r, thanks for the reply. Actually, I don't know what "element 0 of tensors" refers to!
If it means the first layer of my network, that's the frozen one.
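"element 0 of tensors" refers to the first tensor handed to `backward()`, which for `loss.backward()` is the loss itself, not the first layer of the network. This minimal sketch reproduces the message with a loss built only from tensors that do not require grad (the situation `my_loss(data, y)` creates in the code posted further down, since neither `data` nor `y` requires grad).

```python
import torch

a = torch.randn(3)            # requires_grad defaults to False
b = torch.randn(3)
loss = torch.norm(a - b, 2)   # built only from tensors that do not require grad

try:
    loss.backward()           # autograd has nothing to differentiate
    message = None
except RuntimeError as err:
    message = str(err)

print(loss.requires_grad, loss.grad_fn)  # False None
print(message)
```

The printed message is the same "element 0 of tensors does not require grad and does not have a grad_fn" error, regardless of which layers are frozen.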
This is my network:
conv1: frozen
conv2: frozen
conv3: frozen
conv4: trainable
conv5: trainable
conv6: trainable
fully connected (fc1): frozen
fully connected (fc2): frozen
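One way to express the frozen/trainable split above is to freeze whole submodules by name rather than substring-matching parameter names. A minimal sketch, assuming a hypothetical `nn.ModuleDict` stand-in that only mirrors the layer names (its layer shapes are not the real ones):

```python
import torch.nn as nn

# Hypothetical stand-in whose submodule names match the layer list above.
model = nn.ModuleDict({
    **{f"conv{i}": nn.Conv2d(18, 18, kernel_size=1) for i in range(1, 7)},
    "fc1": nn.Linear(18, 10),
    "fc2": nn.Linear(10, 10),
})

FROZEN = {"conv1", "conv2", "conv3", "fc1", "fc2"}  # per the list above

# requires_grad_(flag) sets requires_grad on every parameter of the submodule.
for name, module in model.items():
    module.requires_grad_(name not in FROZEN)

trainable = sorted({n.split(".")[0] for n, p in model.named_parameters() if p.requires_grad})
print(trainable)  # ['conv4', 'conv5', 'conv6']
```

Matching on the top-level module name avoids a pitfall of checks such as `'conv1' in name`, which would also match a layer named `conv10`.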
Hi @InnovArul @matheo_r, thanks a lot for your replies!
My code is something like this:
import torch, torch.nn as nn
import torch.nn.functional as F
import os, sys
import copy
import torch.optim as optim

class ConvReLU(nn.Module):
    def __init__(self, indim, outdim):
        super().__init__()
        self.conv = nn.Conv2d(indim, outdim, kernel_size=1)

    def forward(self, x):
        return F.relu(self.conv(x))

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = ConvReLU(3, 6)
        self.conv2 = ConvReLU(6, 12)
        self.conv3 = ConvReLU(12, 18)
        self.conv4 = ConvReLU(18, 18)
        self.conv5 = ConvReLU(18, 18)
        self.conv6 = ConvReLU(18, 18)
        self.fc1 = nn.Linear(25 * 18, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = F.relu(self.fc1(x.view(x.shape[0], -1)))
        x = self.fc2(x)
        return x

def freeze_params(model):
    # freeze the copied conv blocks and both fully connected layers
    # (iterate over the `model` argument, not the global model2)
    for name, param in model.named_parameters():
        if param.requires_grad and 'conv1' in name:
            param.requires_grad = False
        if param.requires_grad and 'conv2' in name:
            param.requires_grad = False
        if param.requires_grad and 'conv3' in name:
            param.requires_grad = False
        if param.requires_grad and 'fc1' in name:
            param.requires_grad = False
        if param.requires_grad and 'fc2' in name:
            param.requires_grad = False

def copy_weights(src, dst):
    # deepcopy gives dst its own copies of the first three conv blocks
    dst.conv1 = copy.deepcopy(src.conv1)
    dst.conv2 = copy.deepcopy(src.conv2)
    dst.conv3 = copy.deepcopy(src.conv3)

def my_loss(x, y):
    loss = torch.norm(x - y, 2)
    return loss

if __name__ == "__main__":
    data = torch.randn(2, 3, 5, 5).cuda()
    y = torch.randn(2, 3, 5, 5).cuda()
    model1 = CNN().cuda()
    model2 = CNN().cuda()
    copy_weights(model1, model2)
    # the copies must not share storage with model1
    assert model1.conv1.conv.weight.data_ptr() != model2.conv1.conv.weight.data_ptr()
    assert model1.conv2.conv.weight.data_ptr() != model2.conv2.conv.weight.data_ptr()
    assert model1.conv3.conv.weight.data_ptr() != model2.conv3.conv.weight.data_ptr()
    freeze_params(model2)
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model2.parameters()), lr=0.01)
    opt.zero_grad()
    out = model2(data)
    loss = my_loss(data, y)  # note: computed from `data`, not `out`; neither requires grad
    loss.backward()
    opt.step()
And the error is:
RuntimeError Traceback (most recent call last)
<ipython-input-12-c82715d18b62> in <module>()
77 out = model2(data)
78 loss = my_loss(data,y)
---> 79 loss.backward()
80 opt.step()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
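For comparison, a CPU-only sketch of the same training step with the loss computed from the model output `out`. It uses a condensed stand-in for the posted `CNN` (plain `nn.Conv2d` layers instead of `ConvReLU`, so parameter names differ), and it assumes the target `y` has the output's shape `(2, 10)`, because `out - y` would not broadcast against a `(2, 3, 5, 5)` target.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    # Condensed stand-in for the posted CNN: six 1x1 convs + two linear layers.
    def __init__(self):
        super().__init__()
        chans = [3, 6, 12, 18, 18, 18, 18]
        for i in range(6):
            setattr(self, f"conv{i + 1}", nn.Conv2d(chans[i], chans[i + 1], kernel_size=1))
        self.fc1 = nn.Linear(25 * 18, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        for i in range(1, 7):
            x = F.relu(getattr(self, f"conv{i}")(x))
        x = F.relu(self.fc1(x.flatten(1)))
        return self.fc2(x)

torch.manual_seed(0)
data = torch.randn(2, 3, 5, 5)
y = torch.randn(2, 10)  # assumption: target shaped like the model output

model1, model2 = CNN(), CNN()
for name in ("conv1", "conv2", "conv3"):          # copy the shared layers
    setattr(model2, name, copy.deepcopy(getattr(model1, name)))
for name in ("conv1", "conv2", "conv3", "fc1", "fc2"):  # freeze them and the fc layers
    getattr(model2, name).requires_grad_(False)

opt = torch.optim.Adam([p for p in model2.parameters() if p.requires_grad], lr=0.01)
before = model2.conv1.weight.clone()

opt.zero_grad()
out = model2(data)
loss = torch.norm(out - y, 2)  # loss on the model output, not on `data`
loss.backward()                # succeeds: conv4-conv6 still require grad
opt.step()

print(torch.equal(before, model2.conv1.weight))  # True: frozen layer unchanged
```

Gradients flow through the frozen `fc1`/`fc2` back to `conv4`-`conv6`, so the loss has a `grad_fn` even though the last layers are frozen.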
Thanks @raufbhat-dev! I changed requires_grad of the last fc layer to True and the error went away. I don't know what the difference is between the code I posted here and my actual code, but in the actual code the loss is computed correctly as loss = my_loss(out, y), and the problem was caused by the last layer.
Is there any way to set the last layer's requires_grad = False without running into this problem?
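Freezing the last layer by itself should not make `backward()` fail; the error appears only when nothing upstream of the loss requires grad. A minimal sketch on a hypothetical toy `nn.Sequential` model (`backward_ok` is an illustrative helper, not a PyTorch function) shows both cases. Other ways to end up with a loss that has no `grad_fn` include running the forward pass inside `torch.no_grad()` or calling `.detach()` on the output, which may be worth checking when comparing the posted code with the actual code.

```python
import torch
import torch.nn as nn

def backward_ok(model, x):
    # True if loss.backward() succeeds, False if it raises the
    # "element 0 of tensors does not require grad" RuntimeError.
    loss = model(x).sum()
    try:
        loss.backward()
        return True
    except RuntimeError:
        return False

torch.manual_seed(0)
x = torch.randn(4, 8)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

model[2].requires_grad_(False)            # freeze only the last layer
print(backward_ok(model, x))              # True: gradients still reach model[0]
print(model[0].weight.grad is not None, model[2].weight.grad)  # True None

model[0].requires_grad_(False)            # now every parameter is frozen
print(backward_ok(model, x))              # False: the loss has no grad_fn
```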
Interestingly, I am unable to reproduce this either, i.e., even if I set the last layer's requires_grad = False, there is no error. You can find this in the same code that I shared earlier.
Yes, the code you sent is correct, but the code I wrote still raises that error. Changing requires_grad of the last layer fixed it. @InnovArul, could you please explain the purpose of this part?