A really weird thing about the performance of PyTorch models

Hi, everyone! I am confronted by a really weird phenomenon :hot_face: :hot_face: :hot_face:
I’ve written some PyTorch code, whose backbone is listed below. (Please keep in mind that target_feature_extraction_module and target_classification_module have been adequately trained beforehand, so that the classification accuracy reaches 100% when they are used in conjunction with each other.)

    # imports assumed by this snippet; the data loaders, modules, optimizer_list,
    # with_nvidia, and epoch_num are defined elsewhere in the full script
    import numpy as np
    import torch
    import torch.nn as nn
    from sklearn.metrics import accuracy_score

    for cur_epoch in range(epoch_num):
        target_feature_extraction_module.train()
        target_classification_module.train()
        source_feature_extraction_module.train()
        source_to_target_feature_trans.train()
        source_classification_module.train()  # all five models are instances of nn.Module


        target_data = list(enumerate(target_train_loader))
        source_data = list(enumerate(source_train_loader))
        rounds_per_epoch = min(len(target_data), len(source_data))
        for batch_idx in range(rounds_per_epoch):
            _, (target_train,target_label) = target_data[batch_idx]
            _, (source_train,source_label) = source_data[batch_idx]
            if with_nvidia:
                target_train = target_train.float().cuda()
                target_label = target_label.cuda()
                source_train = source_train.float().cuda()
            target_feature = target_feature_extraction_module(target_train)
            source_feature = source_feature_extraction_module(source_train)
            source_shape_changed_feature = source_to_target_feature_trans(source_feature)

            target_classification_result, target_before_last_linear = target_classification_module(target_feature)

            # just print the classification accuracy
            y_predict = target_classification_result.detach().cpu().numpy()
            y_predict = np.argmax(y_predict, axis=1)
            acc = accuracy_score(target_label.cpu().numpy(), y_predict)  # accuracy_score(y_true, y_pred)
            print(acc)  # always equals 1 (in every batch, under every circumstance) since the models have been adequately trained beforehand

            # the following line is the most baffling, even magical, part for me
            #a , b = target_classification_module(source_shape_changed_feature)
            target_classification_loss = nn.CrossEntropyLoss()(target_classification_result, target_label)
            str_out = "Epoch:" + str(cur_epoch) + " batch_num:" + str(batch_idx) + " t_c_loss:" + str(target_classification_loss.item())
            print(str_out)
            target_classification_loss.backward()
            for the_optim in optimizer_list:
                the_optim.step()
            for the_optim in optimizer_list:
                the_optim.zero_grad()
        if cur_epoch%2 == 0:
            target_feature_extraction_module.eval()
            target_classification_module.eval()
            # eval_model_traindata calculates the accuracy over the training set, using the
            # same computation as the print(acc) part above, so its output is expected to be 1 as well
            eval_model_traindata(target_feature_extraction_module, target_classification_module, target_train_loader, cur_epoch, with_nvidia)

Please pay attention to the line “#a , b = target_classification_module(source_shape_changed_feature)” in the middle of the code above. If no changes are made to the code, eval_model_traindata outputs 1. However, if I change this line from a comment into actual code, “a, b = target_classification_module(source_shape_changed_feature)”, the output of eval_model_traindata drops to 0.3, which is really weird! Since the backpropagation process seems irrelevant to a and b, the accuracy and the parameters of the models shouldn’t be affected at all.
I can confirm that the eval_model_traindata function calculates the accuracy of a given model correctly, which is also verified by the fact that, with “#a , b = target_classification_module(source_shape_changed_feature)” left commented out, it outputs 1 as expected.

I’ve run the code many times today. If “a , b = target_classification_module(source_shape_changed_feature)” is commented out, both acc and the output of eval_model_traindata equal 1. However, with “a , b = target_classification_module(source_shape_changed_feature)” uncommented, only acc equals 1 and the output of eval_model_traindata is a value lower than 0.5 :sob:

I sincerely appreciate all your help and suggestions!

Does your model include normalization layers such as BatchNorm? Even if you don’t backpropagate a loss through a or b, the additional forward pass will update the statistics of the normalization layers (e.g., the running mean/variance), which could affect your results. I would check whether calling .eval() on your model before a , b = target_classification_module(source_shape_changed_feature), followed by calling .train() on your model after that line, produces the results that you expect. You should also consider wrapping the a, b = line in with torch.no_grad(): if you do not wish to perform a backward pass from those outputs, as this will save memory.
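For concreteness, here is a minimal sketch of that suggestion (assuming the module and variable names from your snippet, and that you do not need gradients from this extra forward pass):

    # sketch of the suggested fix: run the extra batch with the classifier in
    # eval mode so BatchNorm running statistics are not updated, and without
    # autograd so no graph is built
    target_classification_module.eval()   # use stored running stats; do not update them
    with torch.no_grad():                 # skip gradient tracking for this pass
        a, b = target_classification_module(source_shape_changed_feature)
    target_classification_module.train()  # restore training behavior for later batches

If you do need gradients from a and b later, drop the torch.no_grad() block and keep only the .eval()/.train() toggle.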


Thank you very much! :white_check_mark: Your suggestion really works and helps a lot, since my model does include BatchNorm.
However, I am still confused about why batch normalization can have such a dramatic effect on the performance of a model even though the model has already been trained on so many batches before, and

a , b = target_classification_module(source_shape_changed_feature)

just feeds the model one extra batch of redundant data.

I’m not sure without understanding the details of your scenario (e.g., how many times the extra batch is run, or whether it would be considered very far out of distribution). You could try inspecting the statistics of the normalization layers (e.g., the running_mean and running_var fields) and see if they change substantially.

Additionally, small differences can add up across many layers.
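For instance, a sketch of that inspection (the batchnorm_stats helper is made up for this example; the module names come from your snippet):

    # sketch: snapshot the BatchNorm running statistics before and after the
    # extra forward pass and report how far they moved
    def batchnorm_stats(model):
        stats = {}
        for name, module in model.named_modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                stats[name] = (module.running_mean.clone(), module.running_var.clone())
        return stats

    before = batchnorm_stats(target_classification_module)
    a, b = target_classification_module(source_shape_changed_feature)  # model in train() mode
    after = batchnorm_stats(target_classification_module)
    for name in before:
        d_mean = (after[name][0] - before[name][0]).abs().max().item()
        d_var = (after[name][1] - before[name][1]).abs().max().item()
        print(name, "max |mean shift|:", d_mean, "max |var shift|:", d_var)

Note that with PyTorch’s default BatchNorm momentum of 0.1, a single forward pass in train() mode moves each running statistic 10% of the way toward the statistics of that batch, so one far out-of-distribution batch can already shift them noticeably.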


Thank you very much!!!

Sorry to bother you again! Please allow me to make sure: even if I call .eval() on my model before a , b = target_classification_module(source_shape_changed_feature) and call .train() on my model after that line, if I later perform a backward pass from a and b, the values of the parameters in the model will still be affected, just not the running statistics of the normalization layers?

Yes, that should be the case as long as that part was not wrapped with with torch.no_grad().
You can sanity check whether things have changed by e.g., printing the sum of all of the running stats of the normalization layers of the model.
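A quick version of that sanity check could look like this (a sketch; the running_stats_sum helper is made up):

    # sketch of the sanity check: sum every running statistic of the
    # normalization layers; if this value is identical before and after the
    # extra forward pass, the running stats were not updated
    def running_stats_sum(model):
        total = 0.0
        for module in model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                total += module.running_mean.sum().item() + module.running_var.sum().item()
        return total

    print(running_stats_sum(target_classification_module))  # call before and after the extra pass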
