Hi,
I’m trying to train a simple 2-layer convolutional network to classify CIFAR-10 images, where the layers are trained separately.
For this purpose, I have defined 3 different nn.Modules: two for the convolutional layers, and one for the classifier at the end:
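(For reference, all of the snippets below assume roughly this setup, which I've omitted here: the standard imports and a device handle.)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')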
class Conv1(nn.Module):
    def __init__(self):
        super(Conv1, self).__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        # self.z1_h = torch.nn.Parameter(Variable(torch.randn(64, 32, 16, 16, device=device)),
        #                                requires_grad=True)

    def forward(self, x):
        x1 = self.conv(x)
        z1 = F.relu(x1)
        return x1, z1
class Conv2(nn.Module):
    def __init__(self):
        super(Conv2, self).__init__()
        self.conv = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        # auxiliary input variable (batch of 64, matching the output shape of Conv1)
        self.z1_h = torch.nn.Parameter(Variable(torch.randn(64, 32, 16, 16, device=device)),
                                       requires_grad=True)

    def forward(self, x=None):
        if x is not None:
            x2 = self.conv(x)
        else:
            x2 = self.conv(self.z1_h)
        z2 = F.relu(x2)
        return x2, z2
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(32 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)
        # auxiliary input variable (batch of 64, matching the output shape of Conv2)
        self.z2_h = torch.nn.Parameter(Variable(torch.randn(64, 32, 8, 8, device=device)),
                                       requires_grad=True)

    def forward(self, x=None):
        # flatten the features
        if x is not None:
            flat_feat = x.view(-1, 32 * 8 * 8)
        else:
            flat_feat = self.z2_h.view(-1, 32 * 8 * 8)
        fc1 = self.fc1(flat_feat)
        fc1 = F.relu(fc1)
        fc1 = self.dropout(fc1)
        fc2 = self.fc2(fc1)
        fc2 = F.relu(fc2)
        fc2 = self.dropout(fc2)
        out = self.fc3(fc2)
        out = F.log_softmax(out, 1)
        return out
As can be seen, I have defined the variables self.z1_h and self.z2_h as the inputs to the 2nd and 3rd layers. These two variables are optimized to become equal to the outputs of their preceding layers, z1 and z2.
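To be explicit about how the pieces are coupled, the extra MSE terms in the training code below are what tie them together: loss2 adds MSE(z1, net2.z1_h) on top of MSE(x2, z2) (where x2 = net2.conv(z1_h) and z2 = relu(x2)), and loss3 adds MSE(z2, net3.z2_h) on top of the classification NLL computed on net3(z2_h).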
To optimize this network, I defined 3 different optimizers that update the parameters in an alternating order: the weights (and biases) of the layers, and the two introduced variables self.z1_h and self.z2_h.
optimizer1 = optim.SGD(net1.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer2 = optim.SGD(net2.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer3 = optim.SGD(net3.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# net 1
# zero the parameter gradients
optimizer1.zero_grad()
# forward + backward + optimize
x1, z1 = net1(inputs)
loss1 = F.mse_loss(x1, z1)
loss1.backward(retain_graph=True)
optimizer1.step()
# net 2
# zero the parameter gradients
optimizer2.zero_grad()
# forward + backward + optimize
x2, z2 = net2()
# x2, z2, z2_h = net2(z1_h)
loss2 = F.mse_loss(z1, net2.z1_h) + F.mse_loss(x2, z2)
for param in net2.parameters():
    param.requires_grad = True
net2.z1_h.requires_grad = False  # z1_h constant
loss2.backward(retain_graph=True)
optimizer2.step()
for param in net2.parameters():
    param.requires_grad = False
net2.z1_h.requires_grad = True  # W and bias constant
loss2.backward(retain_graph=True)
optimizer2.step()
# net 3
# zero the parameter gradients
optimizer3.zero_grad()
# forward + backward + optimize
outputs = net3()
loss3 = F.nll_loss(outputs, labels) + F.mse_loss(z2, net3.z2_h)
for param in net3.parameters():
    param.requires_grad = True
net3.z2_h.requires_grad = False  # z2_h constant
loss3.backward(retain_graph=True)
optimizer3.step()
for param in net3.parameters():
    param.requires_grad = False
net3.z2_h.requires_grad = True  # W and bias constant
loss3.backward()
optimizer3.step()
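For context, the three update blocks above run inside an ordinary training loop over the CIFAR-10 DataLoader, roughly like this (the loader and constant names are just what I use; batch_size is 64 with drop_last=True so every batch matches the fixed shapes of z1_h and z2_h):

for epoch in range(NUM_EPOCHS):
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # ... the net1 / net2 / net3 update steps shown above ...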
But the problem is that loss3 stops decreasing after the first few iterations, and in the end the network doesn't learn (classification accuracy of ~10% on CIFAR-10)! I'm not very familiar with PyTorch, and I was wondering if you guys have any thoughts on this. Thanks!