These are my EWC loss and training functions:
def ewc_loss(logits, targets, lamda, fishers, prev_opt_thetas, cur_thetas):
    """Elastic Weight Consolidation penalty.

    Computes (lamda / 2) * sum_i F_i * (theta*_i - theta_i)^2, summed over all
    parameter tensors, where F_i is the Fisher information for parameter i and
    theta*_i is the optimum found on the previous task.

    logits and targets are unused but kept in the signature for compatibility
    with existing callers.

    Returns a scalar tensor (or 0 when `fishers` is empty).
    """
    penalty = 0
    for fisher, prev_opt_theta, cur_theta in zip(fishers, prev_opt_thetas, cur_thetas):
        penalty = penalty + torch.sum(fisher * (prev_opt_theta - cur_theta) ** 2)
    return lamda / 2 * penalty


def train_ewc(model, device, train_loader, optimizer, base_loss_fn, lamda,
              fishers, prev_opt_thetas, epoch, other_loaderA, other_loaderB,
              description=""):
    """Run one epoch of training with an EWC penalty, then evaluate on both tasks.

    NOTE(review): even when the combined loss (and hence every gradient) is
    zero, accuracies on tasks A and B can still change after this runs:
      * `model.train()` puts BatchNorm/Dropout layers in training mode, so the
        forward pass alone updates BatchNorm running statistics; and
      * optimizers with momentum or weight decay (e.g. SGD with momentum, Adam)
        update parameters in `optimizer.step()` even when all grads are zero.
    Confirm against the actual model architecture and optimizer in use.

    Returns (total_loss, ewc_penalty, cross_entropy), each summed over batches.
    """
    model.train()
    loss_train = 0
    loss_ewc_total = 0
    loss_cross_b = 0
    pbar = tqdm(train_loader)
    pbar.set_description(description)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        # Live references to the model's parameters: the penalty stays
        # differentiable w.r.t. the current weights.
        cur_thetas = list(model.parameters())
        optimizer.zero_grad()
        logits = model(inputs)
        loss_crossentropy = base_loss_fn(logits, targets)
        loss_ewc = ewc_loss(logits, targets, lamda, fishers,
                            prev_opt_thetas, cur_thetas)
        total_loss_combined = loss_crossentropy + loss_ewc
        print("total loss is " + str(total_loss_combined)
              + " EWC penalty is " + str(loss_ewc)
              + " cross entropy is " + str(loss_crossentropy))
        loss_train += total_loss_combined.item()
        loss_ewc_total += loss_ewc.item()
        loss_cross_b += loss_crossentropy.item()
        total_loss_combined.backward()
        optimizer.step()
    loss_testB, acc_testB = test(model, device, other_loaderB, base_loss_fn,
                                 description="Test on task B")
    print()
    loss_testA, acc_testA = test(model, device, other_loaderA, base_loss_fn,
                                 description="Test on task A")
    print()
    return loss_train, loss_ewc_total, loss_cross_b
So, when I set my total_loss_combined = loss_crossentropy - loss_crossentropy which is 0, my model results in significantly changed accuracies on task A and B. Accuracy on task A goes from 71% to 28% and that of B goes from 56% to 72% just after 1 epoch of training on B. This is really strange since my loss function is 0.
The strangest part is that when total_loss_combined is the original sum shown above (loss_crossentropy + loss_ewc), I get exactly the same accuracy changes as described, even though total_loss_combined is not zero and neither loss_crossentropy nor loss_ewc is zero.