Model before/after loading weights totally different?

Hello,

I have struggled all day with an issue I have never had before (I usually work with image data).
My model performs well on the validation data of the IMDB dataset during training. I then save the weights.
When I evaluate again on the same validation data, the results are much worse (I am calling model.eval() before the forward pass), so I do not really understand what is going on.
I printed the intermediate results of the different layers and saw that they start to differ from the Linear layer onwards. I was wondering if I have done something wrong in my model:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# SpatialDropout is a custom module defined elsewhere in the attached script.

class SequenceNet(nn.Module):
    def __init__(self, embedding_matrix, num_classes, hidden_sizes=64, padding_value=1):
        super(SequenceNet, self).__init__()
        embed_size = embedding_matrix.shape[1]
        LSTM_UNITS = hidden_sizes
        DENSE_HIDDEN_UNITS = LSTM_UNITS * 4
        self.embedding = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix), freeze=False, sparse=True, padding_idx=padding_value)
        
        self.norm_embedding = nn.LayerNorm(embed_size)
        self.embedding_dropout = SpatialDropout(0.3)
        
        self.lstm1 = nn.LSTM(embed_size, LSTM_UNITS, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(LSTM_UNITS * 2, LSTM_UNITS, bidirectional=True, batch_first=True)
        
        self.norm_seq = nn.LayerNorm(LSTM_UNITS * 2 * 2)
        

        self.fc1 = nn.Linear(DENSE_HIDDEN_UNITS, DENSE_HIDDEN_UNITS)
        self.fc2 = nn.Linear(DENSE_HIDDEN_UNITS, DENSE_HIDDEN_UNITS)
        
        self.out = nn.Linear(DENSE_HIDDEN_UNITS, num_classes)
        
    def extract_mean(self, x_pack, lengths):
        h_t, _ = pad_packed_sequence(x_pack, batch_first=True, padding_value=0.0)
        #print("extract mean", h_t.shape, h_t.sum(dim=1).shape,  lengths.shape)
        h_mean = h_t.sum(dim=1)/lengths.view(-1,1)
        #print(h_mean.shape)
        return h_mean
    
    def extract_max(self, x_pack):
        h_t, _ = pad_packed_sequence(x_pack, batch_first=True, padding_value=-float("inf"))
        
        h_max = h_t.max(dim=1)[0]
        
        return h_max
    
    def forward(self, x, lengths):
        x = self.embedding(x)
        x = self.norm_embedding(x)
        x = self.embedding_dropout(x)
        
        x_pack = pack_padded_sequence(x, batch_first=True, lengths=lengths)
        h_lstm1, _ = self.lstm1(x_pack)
        h_lstm2, _ = self.lstm2(h_lstm1)
        #print("SHAPE :")
        # global average pooling and unpack seq
        avg_pool = self.extract_mean(h_lstm2, lengths)
        #print(avg_pool.shape)
        # global max pooling and unpack seq
        max_pool = self.extract_max(h_lstm2)
        #print(max_pool.shape)
        h_conc = torch.cat((max_pool, avg_pool), 1)
        #print(h_conc.shape)
        h_conc = self.norm_seq(h_conc)
        h_conc_linear1 = F.relu(self.fc1(h_conc))
        #h_conc_linear1 = h_conc + h_conc_linear1
        
        h_conc_linear2 = F.relu(self.fc2(h_conc_linear1))
        
        hidden = h_conc_linear2
        
        result = self.out(hidden)
        
        
        return result
    
    def _forward(self, x, lengths):
        x = self.embedding(x)
        print("embedding :", x)
        x = self.norm_embedding(x)
        x = self.embedding_dropout(x)
        print("norm_embedding :", x)
        x_pack = pack_padded_sequence(x, batch_first=True, lengths=lengths)
        h_lstm1, _ = self.lstm1(x_pack)
        h_lstm2, _ = self.lstm2(h_lstm1)
        #print("SHAPE :")
        # global average pooling and unpack seq
        avg_pool = self.extract_mean(h_lstm2, lengths)
        print("avg pool :", avg_pool)
        #print(avg_pool.shape)
        # global max pooling and unpack seq
        max_pool = self.extract_max(h_lstm2)
        print("max_pool" , max_pool)
        h_conc = torch.cat((max_pool, avg_pool), 1)
        print("h_conc :", h_conc)
        h_conc = self.norm_seq(h_conc)
        print("h_conc_norm :", h_conc)
        h_conc_linear1 = F.relu(self.fc1(h_conc))
        print("h_conc_1 :", h_conc_linear1)
        #h_conc_linear1 = h_conc + h_conc_linear1
        print("h_conc_1_res :", h_conc_linear1)
        h_conc_linear2  = F.relu(self.fc2(h_conc_linear1))
        
        print("h_conc_2 :", h_conc_linear2)
        
        hidden = h_conc_linear2
        print("hidden :", hidden)
        result = self.out(hidden)
        print("result:", result)
        
        return result

The _forward function only prints the intermediate activations for debugging.
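
For reference, the same inspection can be done without duplicating forward by registering forward hooks; a minimal sketch, using the submodule names from the model above:

# Sketch: print a submodule's output via forward hooks instead of _forward.
def make_hook(name):
    def hook(module, inputs, output):
        print(name, ":", output)
    return hook

for name in ["norm_embedding", "fc1", "fc2", "out"]:
    getattr(model, name).register_forward_hook(make_hook(name))

A = next(iter(val_dataloader2))
preds = model(A["text"].to(device), A["lengths"].to(device))  # hooks fire here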

The entire script is available at this address: imdb - Google Drive
When

debug = True

is set, I just evaluate the model on the validation data (I am using the IMDB dataset).
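
For context, the evaluation itself follows the standard pattern, roughly (a sketch; the loss/metric computation is omitted):

# Sketch of the evaluation loop.
model.eval()
with torch.no_grad():  # no gradients needed for evaluation
    for batch in val_dataloader2:
        preds = model(batch["text"].to(device), batch["lengths"].to(device))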

In the link I shared, I also saved the two different printouts of the layer outputs for one forward pass, obtained with:

A = next(iter(val_dataloader2))
preds = model._forward(A["text"].to(device), A["lengths"].to(device))

The results start to differ after the LayerNorm output “h_conc”. Therefore, the issue seems to come from the linear layer, but I have no idea why.
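
A quick way to rule out a serialization problem is to compare the weights themselves before saving and after loading; a sketch, where model2 is assumed to hold the freshly reloaded state:

# Sketch: verify that saved and reloaded parameters actually match.
for (name, p1), (_, p2) in zip(model.state_dict().items(),
                               model2.state_dict().items()):
    if not torch.allclose(p1.float(), p2.float()):
        print("mismatch in", name)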

Does anyone have an idea why I get different results after loading the weights?
I am using PyTorch 1.4 on Windows 10 with CUDA 10.1, in an Anaconda environment with Spyder 4 as the IDE.

EDIT: I have done more tests:

  • Removing linear1 and linear2 seems to solve the issue, but I do not understand why.
  • The issue seems to come from apex: if I train without apex, everything works well.

apex.amp uses monkey patching to transform the data to the appropriate dtype in the forward method.
Based on your code, it seems you are using a custom _forward method, which would be invisible to apex. Could you stick to the vanilla forward definition and check the results again?
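
For illustration, roughly what this means in code (a sketch; the patching details are simplified):

from apex import amp

# amp.initialize returns a model whose forward entry point handles the
# dtype casts under O1; any other method bypasses that machinery.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

out1 = model(x, lengths)           # goes through the patched forward
out2 = model._forward(x, lengths)  # bypasses the patch, no automatic casting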

@ptrblck Hi, thank you for your reply. I have just checked, and unfortunately it does not seem to change anything. To add some information: I do not use the _forward function for the evaluation on the validation data; I only use it to see from which layer on the results differ (and it seems to be the linear1 layer).

Thanks for the update.
How large is the difference after the linear layer and which opt_level are you using?

@ptrblck thank you for your reply.
I am using the “O1” opt_level.
Regarding the linear layer, it depends, but the difference can be huge. For instance, after the first linear layer I can get values around 10e1 or 10e2 instead of 10e-1, and for the last layer it is worse: its output can be around 10e3 or 10e4 instead of the usual 10e-1.

@ptrblck I finally found the issue. It came from computing the gradients with the backward() function: I forgot to use amp.scale_loss. It leads to a weird behavior though, because the training works well until I load the checkpoint again …
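
For anyone hitting the same thing, the backward pass under apex amp has to go through amp.scale_loss; a sketch of the standard O1 training step (criterion, x, lengths, and y stand in for the actual training code):

from apex import amp

# Initialize once, before the training loop.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

optimizer.zero_grad()
loss = criterion(model(x, lengths), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # amp scales the loss and unscales the gradients
optimizer.step()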

Problem solved!

@ptrblck I also see an issue with the SparseAdam optimizer (I have a sparse embedding in my model) when using apex. Is there a way to solve that?

optimizers = [torch.optim.SparseAdam(params=[list(model.parameters())[0]], lr=lr),
              torch.optim.Adam(params=list(model.parameters())[1:], lr=lr, amsgrad=True)]

with amp.scale_loss(loss, optimizers) as scaled_loss:
    scaled_loss.backward()

The error:

  scaled_loss.backward()
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\contextlib.py", line 119, in __exit__
    next(self.gen)
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\site-packages\apex\amp\handle.py", line 123, in scale_loss
    optimizer._post_amp_backward(loss_scaler)
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\site-packages\apex\amp\_process_optimizer.py", line 241, in post_backward_no_master_weights
    post_backward_models_are_masters(scaler, params, stashed_grads)
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\site-packages\apex\amp\_process_optimizer.py", line 120, in post_backward_models_are_masters
    scale_override=grads_have_scale/out_scale)
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\site-packages\apex\amp\scaler.py", line 117, in unscale
    1./scale)
  File "C:\Users\S\Anaconda3\envs\pytorch\lib\site-packages\apex\multi_tensor_apply\multi_tensor_apply.py", line 30, in __call__
    *args)
RuntimeError: sparse tensors do not have is_contiguous (is_contiguous at …\aten\src\ATen\SparseTensorImpl.cpp:55)
(no backtrace available)

No, sparse optimizers are not supported at the moment.
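
One possible workaround (a sketch, not an official apex feature) is to avoid sparse gradients altogether, so that a single dense optimizer can be used:

# In SequenceNet.__init__: create the embedding with sparse=False so plain
# Adam can handle all parameters and SparseAdam is not needed.
self.embedding = nn.Embedding.from_pretrained(
    torch.tensor(embedding_matrix), freeze=False,
    sparse=False, padding_idx=padding_value)

# A single optimizer then works with amp.scale_loss:
optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)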

Good to hear you figured out the initial issue!
