I have a simple CNN with two parts: an extractor (convolutional layers) and a classifier (fully connected layers).
It's a binary classification problem.
Scenario 1:
When I train using only source_dataset, as in the code below, the network trains well.
Code:
source_feature = extractor(source_eeg)
classifier_out = classifier(source_feature)
class_loss = criterion(classifier_out, source_main_labels)
loss = class_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
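For anyone who wants to reproduce this, here is a minimal, self-contained sketch of the Scenario 1 step. The architecture, tensor shapes, random placeholder data, `nn.CrossEntropyLoss` and the Adam optimizer are all assumptions (hypothetical stand-ins for the real EEG model and data). Only the order of the loss/backward/step/zero_grad calls comes from the snippet above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real extractor (CNN) and classifier (FC layers).
extractor = nn.Sequential(
    nn.Conv1d(4, 8, kernel_size=5),  # assumed: 4 EEG channels in, 8 feature maps out
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(4),
    nn.Flatten(),                    # -> (batch, 8 * 4)
)
classifier = nn.Linear(8 * 4, 2)     # two logits for the binary task

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    list(extractor.parameters()) + list(classifier.parameters()), lr=1e-2
)

# Random placeholder batch: 16 trials, 4 channels, 64 time samples.
source_eeg = torch.randn(16, 4, 64)
source_main_labels = torch.randint(0, 2, (16,))

losses = []
for step in range(100):
    source_feature = extractor(source_eeg)
    classifier_out = classifier(source_feature)
    class_loss = criterion(classifier_out, source_main_labels)
    loss = class_loss
    loss.backward()          # compute gradients
    optimizer.step()         # update extractor + classifier weights
    optimizer.zero_grad()    # clear gradients for the next step
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Repeatedly fitting one fixed batch like this is only a sanity check that the step is wired correctly; it says nothing about generalization.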
Training output:
epoch: 1
source_train | class_loss: 0.689 class_accu: 0.544
epoch: 2
source_train | class_loss: 0.658 class_accu: 0.611
epoch: 3
source_train | class_loss: 0.626 class_accu: 0.679
epoch: 4
source_train | class_loss: 0.602 class_accu: 0.744
epoch: 5
source_train | class_loss: 0.568 class_accu: 0.802
epoch: 6
source_train | class_loss: 0.556 class_accu: 0.873
epoch: 7
source_train | class_loss: 0.567 class_accu: 0.892
epoch: 8
source_train | class_loss: 0.551 class_accu: 0.921
epoch: 9
source_train | class_loss: 0.518 class_accu: 0.937
Scenario 2:
If I additionally forward a second dataset (target) through the extractor, without using its output in the loss, the model no longer learns.
Code:
source_feature = extractor(source_eeg)
classifier_out = classifier(source_feature)
class_loss = criterion(classifier_out, source_main_labels)
target_feature = extractor(target_eeg)  # extra forward pass; target_feature is not used in the loss
loss = class_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
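One side effect worth ruling out: if the real extractor contains BatchNorm layers (an assumption, since the architecture is not shown), every forward call in training mode updates their `running_mean`/`running_var` buffers, even when the output never reaches the loss and even under `torch.no_grad()`. The sketch below uses a hypothetical extractor and placeholder data to show this; it does not prove that BatchNorm is the cause here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical extractor containing a BatchNorm layer (assumption about the real model).
extractor = nn.Sequential(
    nn.Conv1d(4, 8, kernel_size=5),
    nn.BatchNorm1d(8),
    nn.ReLU(),
)
extractor.train()  # training mode, as in the loop above
bn = extractor[1]

# Placeholder target batch with a different distribution from the source data.
target_eeg = torch.randn(16, 4, 64) * 3.0 + 2.0

before = bn.running_mean.clone()
with torch.no_grad():                 # no autograd graph, no gradients...
    target_feature = extractor(target_eeg)
after = bn.running_mean.clone()

# ...yet the BatchNorm buffers were still updated by the forward pass.
print("running_mean changed:", not torch.equal(before, after))
print("batches tracked:", bn.num_batches_tracked.item())
```

In training mode the source batch is still normalized with its own batch statistics, so this drift shows up mainly when the model is later run in `eval()` mode; whether it matters here depends on how `class_accu` is computed.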
Training output:
epoch: 1
source_train | class_loss: 0.692 class_accu: 0.521
epoch: 2
source_train | class_loss: 0.693 class_accu: 0.533
epoch: 3
source_train | class_loss: 0.691 class_accu: 0.521
epoch: 4
source_train | class_loss: 0.693 class_accu: 0.478
epoch: 5
source_train | class_loss: 0.693 class_accu: 0.470
epoch: 6
source_train | class_loss: 0.693 class_accu: 0.465
epoch: 7
source_train | class_loss: 0.693 class_accu: 0.465
epoch: 8
source_train | class_loss: 0.693 class_accu: 0.465
epoch: 9
source_train | class_loss: 0.693 class_accu: 0.465
The problem is that training on the first dataset (source) is affected merely by adding a line that forward-passes the second dataset (target): the loss stays at about 0.693 (≈ ln 2, which is what a cross-entropy loss gives for a constant 50/50 prediction) and the accuracy hovers around chance.
EDIT: I understand that forwarding the second dataset in the code above is pointless, since no loss is computed or backpropagated from it; I intend to add that later. My question is why merely forwarding the second dataset also affects learning on the first dataset.
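To check whether the interference comes through shared state rather than through gradients, one option is to isolate the extra forward pass. The helper `forward_without_side_effects` below is hypothetical (not part of PyTorch): it runs the extractor in `eval()` mode under `torch.no_grad()` and saves/restores the CPU random-number state with `torch.get_rng_state()`/`torch.set_rng_state()`, because a training-mode Dropout layer draws from that generator and would otherwise shift later dropout masks (and possibly DataLoader shuffling). The model and data are placeholders; on a GPU the CUDA generator would need the same treatment.

```python
import torch
import torch.nn as nn


def forward_without_side_effects(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical diagnostic helper: run `module` on `x` without updating
    BatchNorm buffers, sampling dropout masks, building an autograd graph,
    or advancing the CPU RNG."""
    was_training = module.training
    rng_state = torch.get_rng_state()   # CPU generator only (assumes a CPU run)
    module.eval()                       # BatchNorm uses running stats; Dropout is a no-op
    try:
        with torch.no_grad():
            out = module(x)
    finally:
        module.train(was_training)      # restore the original mode
        torch.set_rng_state(rng_state)  # restore the random stream
    return out


torch.manual_seed(0)
# Hypothetical extractor with BatchNorm and Dropout (assumption about the real model).
extractor = nn.Sequential(
    nn.Conv1d(4, 8, kernel_size=5), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5)
)
extractor.train()
bn = extractor[1]

target_eeg = torch.randn(16, 4, 64) * 3.0 + 2.0  # placeholder target batch
stats_before = bn.running_mean.clone()
rng_before = torch.get_rng_state()

target_feature = forward_without_side_effects(extractor, target_eeg)

print("running_mean unchanged:", torch.equal(stats_before, bn.running_mean))
print("RNG state unchanged:", torch.equal(rng_before, torch.get_rng_state()))
print("still in training mode:", extractor.training)
```

If training recovers with the target pass wrapped like this, that points to BatchNorm statistics or the random stream; once the target loss is actually backpropagated, the target pass will most likely need to run in training mode again.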
Thanks for the help!