TF-Keras to PyTorch Model conversion target and input size mismatch

I’m trying to convert a TensorFlow-Keras model to PyTorch, and encountered the following error:

Traceback (most recent call last):
File "model.py", line 480, in <module>
train_loop(model, device, train_dataloader, val_dataloader, optimizer, scheduler, model_name, epochs)
File "model.py", line 164, in train_loop
outputs = model(**inputs)
File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "model.py", line 104, in forward
loss = loss_fct(output, labels) # for logsoftmax
File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/user/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 1048, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "/home/user/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2693, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/home/user/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2397, in nll_loss
raise ValueError("Expected target size {}, got {}".format(out_size, target.size()))
ValueError: Expected target size (32, 3), got torch.Size([32])

The outputs for the target and labels are as follows:

output size: torch.Size([32, 256, 3])
labels size: torch.Size([32])
---Outputs---
tensor([[[-5.8398e-02, -2.1934e-02, -1.6847e-01],
         [ 5.8547e-03, -7.5034e-02, -1.8293e-01],
         [-1.3510e-01,  9.7023e-03, -7.3464e-02],
         ...,
         [ 2.6803e-03,  2.1820e-01, -6.6754e-03],
         [-1.2291e-01,  2.2271e-01, -1.3196e-01],
         [-5.0330e-02,  1.9224e-01, -1.0136e-02]],

        [[-1.8273e-01, -1.0688e-01, -9.2163e-02],
         [-1.3737e-01,  1.7531e-01,  9.0535e-02],
         [-1.0511e-02,  1.1870e-01, -2.0116e-01],
         ...,
         [-1.8598e-01,  2.6075e-01,  2.2549e-02],
         [-3.1560e-02,  2.8244e-01, -1.0405e-01],
         [-3.9868e-02,  3.9498e-01,  7.8026e-02]],

        [[-8.0235e-02,  9.4451e-02, -6.5826e-02],
         [-3.1021e-02,  1.2857e-01,  3.4519e-01],
         [ 2.1775e-02, -1.2628e-02,  1.4941e-01],
         ...,
         [-7.4938e-02,  1.8854e-01,  7.0910e-02],
         [-7.2939e-02,  1.2330e-01, -3.8597e-02],
         [-2.4675e-04,  1.5793e-01, -1.0282e-01]],

        ...,

        [[-3.1385e-01, -3.4715e-02, -1.4354e-01],
         [-2.3148e-01,  1.3139e-01,  1.2541e-02],
         [ 5.8533e-02,  1.1105e-01,  2.9554e-01],
         ...,
         [-1.7847e-02,  2.0214e-01, -2.2909e-02],
         [-8.9989e-03,  1.1894e-01, -5.3676e-02],
         [ 1.8850e-06,  2.0262e-01,  8.1679e-02]],

        [[-3.9956e-01, -2.4191e-02, -9.5190e-02],
         [-2.7695e-01,  1.4200e-01,  2.1383e-01],
         [ 1.7979e-02, -4.5267e-02,  1.9809e-01],
         ...,
         [ 1.3259e-02,  1.3560e-01, -1.6332e-01],
         [-7.5140e-02, -4.5789e-03, -1.5763e-01],
         [ 1.4075e-02,  9.4575e-02, -2.6298e-02]],

        [[-2.8567e-01,  8.0501e-03,  8.6712e-02],
         [-3.9667e-03,  9.8984e-02,  1.8594e-01],
         [-1.7376e-01, -1.6987e-02, -9.0209e-04],
         ...,
         [-1.8293e-01,  1.4832e-01, -7.7277e-02],
         [-2.8346e-01,  6.8508e-04, -7.5891e-02],
         [-5.7721e-02,  6.0411e-02, -7.9864e-02]]], device='cuda:0',
       grad_fn=<AddBackward0>)
---Labels---
tensor([1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 1, 0, 2, 2, 0, 2, 0, 2, 2,
        1, 0, 2, 0, 0, 2, 1, 1], device='cuda:0')

The original TensorFlow-Keras model is:

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Encoded token ids from BERT tokenizer.
    input_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="input_ids"
        )
    # Attention masks indicates to the model which tokens should be attended to.
    attention_masks = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="attention_masks"
        )
    # Token type ids are binary masks identifying different sequences in the model.
    token_type_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="token_type_ids"
        )
    # Loading pretrained BERT model.
    bert_model = transformers.TFBertModel.from_pretrained("bert-base-uncased")
    # Freeze the BERT model to reuse the pretrained features without modifying them.
    bert_model.trainable = False
    
    sequence_output, pooled_output = bert_model.bert(
        input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
        )
    # Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.
    bi_lstm = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, return_sequences=True)
        )(sequence_output)
    # Applying hybrid pooling approach to bi_lstm sequence output.
    avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
    max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
    concat = tf.keras.layers.concatenate([avg_pool, max_pool])
    dropout = tf.keras.layers.Dropout(0.3)(concat)
    output = tf.keras.layers.Dense(3, activation="softmax")(dropout)
    model = tf.keras.models.Model(
        inputs=[input_ids, attention_masks, token_type_ids], outputs=output
        )
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["acc"],
        )

And the PyTorch model is:

class CustomModel(torch.nn.Module):
    def __init__(self, num_labels=3):
        super(CustomModel, self).__init__()
        self.base_model = BertModel.from_pretrained("bert-base-uncased") 
        self.lstm = torch.nn.LSTM(768, 64, bidirectional=True)
        self.globalavgpooling = torch.nn.AdaptiveAvgPool1d(128)
        self.globalmaxpooling = torch.nn.AdaptiveMaxPool1d(128)
        self.dropout = torch.nn.Dropout(p=0.3)
        self.linear = torch.nn.Linear(128, num_labels) 
        
    def forward(self, input_ids, attention_mask, labels):
        sequence_output, pooled_output = self.base_model(input_ids, attention_mask=attention_mask, return_dict=False)
        
        lstm_output, (h,c) = self.lstm(sequence_output) 
        avg_pool = self.globalavgpooling(lstm_output)
        max_pool = self.globalmaxpooling(lstm_output)
        
        concat = torch.cat((avg_pool, max_pool), dim=1)
        print(f'concat size: {concat.size()}')
        dropout = self.dropout(concat)
        print(f'dropout size: {dropout.size()}')
        
        output = self.linear(dropout)
        print(f'output size: {output.size()}')
        print(f'labels size: {labels.size()}')
        
        # compile model 
        loss_fct = torch.nn.CrossEntropyLoss()
        print('---Outputs---')
        print(output)
        print('---Labels---')
        print(labels)
        loss = loss_fct(output, labels) 
        
        return loss, output 

If it helps, the model summaries for TensorFlow-Keras model and PyTroch model are as follows:

(TensorFlow-Keras model) https://i.stack.imgur.com/apZ8X.png

(PyTorch model) https://i.stack.imgur.com/tToMm.png

Note that for the PyTorch model I did not declare the trainable layers, but that should not be the problem here.

Did I do something wrong in the mapping of the equivalent layers in PyTorch? I do notice that the total number of parameters are slightly different (differs by 128), but am unable to figure out why is that the case. If it helps, the batch size is 32.

The error is raised in loss function (either nn.CrossEntropyLoss or nn.NLLLoss) due to a shape mismatch.
I assume you are working on a multi-class classification and the target contains class indices.
Your current model output has the shape [32, 256, 3], which corresponds to [batch_size, nb_classes, seq_len], while the target is expected to have the shape [batch_size, seq_len] containing the class indices in the range [0, nb_classes-1]. Currently the target has the shape [batch_size=32], so it contains a target index for each sample while targets for each time step are also expected.

Thanks for the prompt feedback! Yes, its a multi-class classification problem.

With regards to the points you have raised, am I right to interpret that I should transform the target (which is of current shape [32] to the shape [32, 3]? Which means transforming

tensor([1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 1, 0, 2, 2, 0, 2, 0, 2, 2, 1, 0, 2, 0, 0, 2, 1, 1], device='cuda:0')

to something like

tensor([[0, 1, 0], [0, 0, 1], ..., [0, 1, 0]], device='cuda:0')

for the loss function to work?

No, your current transformation would one-hot encode the targets, which is not expected.
The issue is that your current model outputs logits for a time sequence while the target doesn’t contain labels for each time step.
Could you describe what the output shape of the model represents?

The output should reflect the number of classes which is 3. But here it is 256 (if I understood the documentation correctly). Did something go wrong in the layers?

I guess the “reduction” in the temporal dimension is missing in your model or alternatively you could only use the last time step from the LSTM output. Currently your model is returning logits for a sequence, which doesn’t fit the target.

Thanks for pointing that out, I made the changes to the model:

class CustomModel(torch.nn.Module):
    def __init__(self, num_labels=3):
        super(CustomModel, self).__init__()
        self.base_model = BertModel.from_pretrained("bert-base-uncased") 
        self.lstm = torch.nn.LSTM(768, 64, bidirectional=True, batch_first=True)
        self.globalavgpooling = torch.nn.AdaptiveAvgPool1d(128) 
        self.globalmaxpooling = torch.nn.AdaptiveMaxPool1d(128)
        self.dropout = torch.nn.Dropout(p=0.3)
        self.linear = torch.nn.Linear(128, num_labels) 
        
    def forward(self, input_ids, attention_mask, labels):
        sequence_output, pooled_output = self.base_model(input_ids, attention_mask=attention_mask, return_dict=False)
        
        lstm_output, (h,c) = self.lstm(sequence_output) 
        avg_pool = self.globalavgpooling(lstm_output)
        max_pool = self.globalmaxpooling(lstm_output)
        
        concat = torch.cat((avg_pool, max_pool), dim=1)
        dropout = self.dropout(concat)
        
        output = self.linear(dropout[:, -1, :]) # Added [:, -1, :]
        
        # compile model 
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(output, labels) 
        
        return loss, output

where I modified the outputs by using the last time step as you have suggested:

output = self.linear(dropout[:, -1, :]) # Added [:, -1, :]

after invoking batch_first=True in the LSTM layer.

Was this what you meant and implemented correctly?

Usually you would index the last time step in lstm_output, but your approach might also work.

Thanks for clarifying @ptrblck. I think its the same as doing the indexing at the output layer is akin to the last time step.

Right, that’s what I thought while looking at your model definition, but I haven’t verified it :stuck_out_tongue:

1 Like