This is my training function:

```
def ewc_loss(logits, targets, lamda, fishers, prev_opt_thetas, cur_thetas):
    """Elastic Weight Consolidation penalty.

    Computes (lamda / 2) * sum_i F_i * (theta*_i - theta_i)^2, summed over
    all parameter tensors, where theta*_i is the optimum found on the
    previous task and F_i its Fisher information estimate.

    Note: `logits` and `targets` are accepted only for call-site
    compatibility; the penalty does not depend on them.
    """
    # Walk the three parameter lists in lockstep (assumes equal lengths,
    # matching the original index-based loop over len(fishers)).
    penalty = 0
    for fisher, opt_theta, theta in zip(fishers, prev_opt_thetas, cur_thetas):
        penalty = penalty + torch.sum(fisher * (opt_theta - theta) ** 2)
    return lamda / 2 * penalty
def train_ewc(model, device, train_loader, optimizer, base_loss_fn,
              lamda, fishers, prev_opt_thetas, epoch, other_loaderA, other_loaderB, description=""):
    """Run one epoch of EWC-regularized training, then evaluate on both tasks.

    The per-batch objective is base_loss_fn(logits, targets) plus the
    quadratic EWC penalty anchoring the current parameters to the previous
    task's optimum (`prev_opt_thetas`, weighted by `fishers` and `lamda`).
    After the epoch finishes, the model is evaluated on `other_loaderB`
    and then on `other_loaderA` via the module-level `test` helper.
    """
    model.train()
    running_total = 0
    running_ewc = 0
    running_ce = 0
    progress = tqdm(train_loader)
    progress.set_description(description)
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)
        # Live parameter tensors; the penalty compares them to the stored optimum.
        params_now = list(model.parameters())
        optimizer.zero_grad()
        logits = model(inputs)
        ce = base_loss_fn(logits, targets)
        penalty = ewc_loss(logits, targets, lamda, fishers,
                           prev_opt_thetas, params_now)
        combined = ce + penalty
        print("total loss is " + str(combined) + " EWC penalty is " + str(penalty) + " cross entropy is " + str(ce))
        # Accumulate epoch-level sums before backprop (values are unchanged by it).
        running_total += combined.item()
        running_ewc += penalty.item()
        running_ce += ce.item()
        combined.backward()
        print("loss grad is ", combined)
        optimizer.step()
    # NOTE(review): the pasted original lost its indentation; these evaluations
    # are placed after the training loop, consistent with "after 1 epoch".
    loss_testB, acc_testB = test(model, device, other_loaderB, base_loss_fn, description="Test on task B")
    print()
    loss_testB, acc_testB = test(model, device, other_loaderA, base_loss_fn, description="Test on task A")
    print()
```

So, when I set `total_loss_combined = loss_crossentropy - loss_crossentropy`, which is 0, my model's accuracies on tasks A and B still change significantly: accuracy on task A drops from 71% to 28%, and accuracy on task B rises from 56% to 72%, after just 1 epoch of training on B. This is really strange, since my loss function is 0.

The strangest part is that when `total_loss_combined` is the sum defined in the code above (`loss_crossentropy + loss_ewc`), I get the same results as mentioned above, even though `total_loss_combined` is not zero, and neither is `loss_crossentropy` or `loss_ewc`.