I have a complex model that calculates the low-rank matrix and try to minimize it while training CNN. The training wrapper is the following:
def train_generalization(args, modelc, model1, model2, device, train_loader_combined, optimizer, epoches, criterion,batch_size):
for epoch in range(epoches): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(train_loader_combined, 0):
input_combined,labels_combined, flag = data
labels_combined,index =torch.sort(labels_combined)
input_combined=input_combined[index]
flag = flag[index]
input_combined, labels_combined, flag = input_combined.to(device), labels_combined.to(device), flag.to(device)
optimizer.zero_grad()
Hc, outputs1 = modelc(input_combined) # domain invariant representation
Lg = criterion(outputs1, labels_combined)
Hs1, outputs2 = model1(input_combined[flag==1])
Hs2, outputs3 = model2(input_combined[flag==2])
max_iteration = 1000
if (epoch >= 2):
max_iteration = 1000
Hs = torch.cat((Hs1, Hs2), 0)
labels_Hs = torch.cat((labels_combined[flag==1], labels_combined[flag==2]), 0)
labels_Hs,index = torch.sort(labels_Hs)
Hs = Hs[index]
Q = get_Q(labels_combined, labels_combined, batch_size)
Z,ZZ,E = calculate_Z(torch.transpose(Hc,0,1),torch.transpose(Hs,0,1), Q, device, batch_size)
Lr = get_nuc_norm(Z)+ get_fib_norm(Z-Q)
Lr = Lr.to(device)
loss = Lr
loss.backward()
optimizer.step()
print(Lr)
My CNN is as the following:
def __init__(self):
super(model_gen, self).__init__()
self.conv1 = nn.Conv2d(1, 10, 5)
self.conv2 = nn.Conv2d(10, 20, 5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(20 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x= self.conv2_drop(x)
x = x.view(-1, self.num_flat_features(x))
x = self.fc1(x) # matrix of 20*5*,120
x1 = self.fc2(x) # vector of 84
x2 = self.fc3(x1) # vector of 10 which are the number of classes
x4 = F.relu(x2)
For some reason, after the first iteration, the fc1 and fc2 wights gives me NAN while fc3 gives me normal weights
Can anybody direct me on how to figure out the problem?