I have a complex model that calculates a low-rank matrix and tries to minimize it while training a CNN. The training wrapper is as follows:

def train_generalization(args, modelc, model1, model2, device, train_loader_combined, optimizer, epoches, criterion, batch_size):
    """Train the three models on the low-rank loss for `epoches` passes.

    For every batch the samples are sorted by label, `modelc` is run on the
    whole batch, and `model1` / `model2` are run on the samples whose flag is
    1 / 2 respectively.  The loss that is back-propagated is
    `get_nuc_norm(Z) + get_fib_norm(Z - Q)`, where `Z` comes from
    `calculate_Z` and `Q` from `get_Q`.

    Args:
        args: Unused by this function; kept for interface compatibility.
        modelc: Callable returning `(representation, outputs)` for a batch.
        model1: Callable returning `(representation, outputs)`; fed the
            samples with `flag == 1`.
        model2: Callable returning `(representation, outputs)`; fed the
            samples with `flag == 2`.
        device: Device the batch tensors and the loss are moved to; also
            forwarded to `calculate_Z`.
        train_loader_combined: Iterable yielding `(inputs, labels, flag)`.
        optimizer: Optimizer whose parameters are updated every batch.
        epoches: Number of passes over `train_loader_combined`.
        criterion: Loss callable applied to `(outputs, labels)`.
        batch_size: Forwarded to `get_Q` and `calculate_Z`.

    Returns:
        None.  The low-rank loss is printed once per batch.
    """
    # The original also set `max_iteration = 1000` (again when epoch >= 2) and
    # `running_loss = 0.0`; neither value was ever read, so both are dropped.
    for _epoch in range(epoches):  # loop over the dataset multiple times
        for input_combined, labels_combined, flag in train_loader_combined:
            # Sort the batch by label and apply the same permutation to the
            # inputs and flags so the three tensors stay aligned.
            labels_combined, index = torch.sort(labels_combined)
            input_combined = input_combined[index]
            flag = flag[index]

            input_combined = input_combined.to(device)
            labels_combined = labels_combined.to(device)
            flag = flag.to(device)

            optimizer.zero_grad()

            Hc, outputs1 = modelc(input_combined)  # domain invariant representation

            # NOTE(review): Lg is computed but is NOT part of `loss` below, so
            # any layer that only feeds `outputs1` receives no gradient from
            # this step.  If the classification loss is meant to be trained
            # jointly, use `loss = Lg + Lr` -- confirm the intended objective.
            Lg = criterion(outputs1, labels_combined)

            # Domain-specific representations; their outputs are not used here.
            Hs1, _ = model1(input_combined[flag == 1])
            Hs2, _ = model2(input_combined[flag == 2])

            # Stack both domain-specific representations and re-sort them by
            # label so their row order follows the sorted labels.
            Hs = torch.cat((Hs1, Hs2), 0)
            labels_Hs = torch.cat((labels_combined[flag == 1], labels_combined[flag == 2]), 0)
            _, index = torch.sort(labels_Hs)
            Hs = Hs[index]

            Q = get_Q(labels_combined, labels_combined, batch_size)
            Z, _, _ = calculate_Z(torch.transpose(Hc, 0, 1), torch.transpose(Hs, 0, 1), Q, device, batch_size)

            Lr = get_nuc_norm(Z) + get_fib_norm(Z - Q)
            Lr = Lr.to(device)

            loss = Lr
            loss.backward()

            # A single NaN/Inf gradient would permanently poison the weights
            # once applied, so only step when every gradient is finite.
            grads_finite = all(
                bool(torch.isfinite(param.grad).all())
                for group in optimizer.param_groups
                for param in group["params"]
                if param.grad is not None
            )
            if grads_finite:
                optimizer.step()
            else:
                print("Skipping optimizer step: non-finite gradient detected")

            print(Lr)

My CNN is as follows:

```
def __init__(self):
    """Create the layers: two 5x5 convolutions, channel dropout, three linear layers."""
    super(model_gen, self).__init__()

    # Convolutional stage: 1 -> 10 -> 20 channels, 5x5 kernels.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()

    # Fully connected stage: 20*5*5 flattened features -> 120 -> 84 -> 10.
    self.fc1 = nn.Linear(in_features=20 * 5 * 5, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=10)
# Forward pass: two conv + ReLU + max-pool stages, dropout on the conv
# features, a flatten, then the three fully connected layers.
# NOTE(review): no `return` statement is visible in this snippet, so as written
# the method returns None, while the training loop unpacks two values from each
# model call. The return line was presumably lost when pasting -- TODO confirm
# which tensors are returned.
def forward(self, x):
# conv1 -> ReLU -> 2x2 max-pool
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# conv2 -> ReLU -> 2x2 max-pool
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# Dropout2d applied to the conv2 feature maps
x= self.conv2_drop(x)
# Flatten to (-1, features); `num_flat_features` is defined outside this
# snippet -- presumably the product of the non-batch dimensions, verify.
x = x.view(-1, self.num_flat_features(x))
# NOTE(review): no activation is applied between fc1, fc2 and fc3, so the
# three layers compose into a single linear map.
x = self.fc1(x) # maps the 20*5*5 flattened features to 120
x1 = self.fc2(x) # vector of 84
x2 = self.fc3(x1) # vector of 10 which are the number of classes
# NOTE(review): ReLU is applied to the fc3 output; if `criterion` expects raw
# logits this zeroes every negative score -- confirm this is intended.
x4 = F.relu(x2)
```

For some reason, after the first iteration, the fc1 and fc2 weights become NaN, while fc3 still has normal weights.

Can anybody direct me on how to figure out the problem?