Facing error: `Expected input batch_size (72) to match target batch_size (50)`

Hi all,
I am trying to extract a name from a phonetic transcription and am trying to fine-tune a model for this. Here is the code:

#!/usr/bin/env python
# coding: utf-8
# Let's import the relevant things for the model.
# In[48]:


import torch
from torch.utils.data import DataLoader,Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AdamW
from torch.nn.utils.rnn import pad_sequence
# In[49]:
def collate_batch(batch):
    # Sort the batch by input sequence length
    batch = sorted(batch, key=lambda x: len(x['input_ids']))
    # Pad sequences to have the same length within the batch
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True)
    target_ids = pad_sequence([item['target_ids'] for item in batch], batch_first=True)
    return {'input_ids': input_ids, 'target_ids': target_ids}


# In[50]:
def fine_tune_model(train_dataset, model, tokenizer, epochs=3, learning_rate=1e-5):
    # Set the device (CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Create DataLoader for training dataset
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,  collate_fn=collate_batch)
    # Set up optimizer and loss function
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    # Fine-tune the model
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].squeeze().to(device)
            target_ids = batch['target_ids'].squeeze().to(device)
            # Shift target_ids by one position to the right
            shifted_target_ids = torch.roll(target_ids, shifts=1, dims=1)
            # Check if tokenizer.pad_token_id is None and replace it with a valid token ID
            if tokenizer.pad_token_id is None:
                pad_token_id = 0  # Replace with a valid non-padding token ID
            else:
                pad_token_id = tokenizer.pad_token_id

            shifted_target_ids[:, 0] = pad_token_id  # Set the first position to the padding token

            # Forward pass
            outputs = model(input_ids, labels=shifted_target_ids)
            loss = outputs.loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {average_loss}")

    print("Fine-tuning complete!")


# In[51]:
def extract_phonetic_transcription(model, tokenizer, input_text):
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate output from the model
    with torch.no_grad():
        output_ids = model.generate(input_ids)

    # Decode the output and extract the relevant part
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Assuming the relevant part starts after the first space
    relevant_part = output_text.split(' ', 1)[1]

    return relevant_part

# In[52]:
class PhoneticTranscriptionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Tokenize input and target sequences
        input_text = sample['input']
        target_text = sample['target']

        input_ids = self.tokenizer.encode(input_text, return_tensors='pt').squeeze()
        target_ids = self.tokenizer.encode(target_text, return_tensors='pt').squeeze()

        return {'input_ids': input_ids, 'target_ids': target_ids}


# In[53]:


# Example usage
if __name__ == "__main__":
    
    # Your dataset with pairs of input and target sequences
    training_data = [
        {'input': "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i", 'target': "s ʌ m i ɹ z o w ʃ i"},
        {'input': "m a j n e j m ɪ z b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i", 'target': "b ʌ s o w ɹ ɑ d͡ʒ ɡ ʊ l i"}
        # Add more examples as needed
    ]
    
 

    # Load pre-trained GPT-2 model and tokenizer
    model_name = 'gpt2'
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    
    # Create an instance of the dataset
    train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)

    # Fine-tune the model on your dataset (provide your own implementation)
    # train_dataset = ...  # Your prepared dataset
    fine_tune_model(train_dataset, model, tokenizer)

    # Example input
    input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"

    # Extract relevant part using the model
    relevant_part = extract_phonetic_transcription(model, tokenizer, input_phonetic_transcription)

    print("Input Phonetic Transcription:", input_phonetic_transcription)
    print("Extracted Relevant Part:", relevant_part)

I am getting the following error:

ValueError                                Traceback (most recent call last)
Cell In[53], line 23
     19 train_dataset = PhoneticTranscriptionDataset(training_data, tokenizer)
     21 # Fine-tune the model on your dataset (provide your own implementation)
     22 # train_dataset = ...  # Your prepared dataset
---> 23 fine_tune_model(train_dataset, model, tokenizer)
     25 # Example input
     26 input_phonetic_transcription = "m a j n e j m ɪ z s ʌ m i ɹ z o w ʃ i"

Cell In[50], line 33, in fine_tune_model(train_dataset, model, tokenizer, epochs, learning_rate)
     30 shifted_target_ids[:, 0] = pad_token_id  # Set the first position to the padding token
     32 # Forward pass
---> 33 outputs = model(input_ids, labels=shifted_target_ids)
     34 loss = outputs.loss
     36 # Backward pass and optimization

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\transformers\models\gpt2\modeling_gpt2.py:1108, in GPT2LMHeadModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1106 # Flatten the tokens
   1107 loss_fct = CrossEntropyLoss()
-> 1108 loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
   1110 if not return_dict:
   1111     output = (lm_logits,) + transformer_outputs[1:]

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
   1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174     return F.cross_entropy(input, target, weight=self.weight,
   1175                            ignore_index=self.ignore_index, reduction=self.reduction,
   1176                            label_smoothing=self.label_smoothing)

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\functional.py:3029, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3027 if size_average is not None or reduce is not None:
   3028     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

ValueError: Expected input batch_size (72) to match target batch_size (50).
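
If it helps, this is the shape check I can add inside fine_tune_model right before the forward pass (a minimal sketch; the tensor names are the ones from my training loop above):

# Inside the training loop in fine_tune_model, just before outputs = model(...):
print("input_ids shape:", input_ids.shape)                    # (batch, input_length)
print("shifted_target_ids shape:", shifted_target_ids.shape)  # (batch, target_length)
# The traceback shows GPT2LMHeadModel flattening the shifted logits and labels
# before calling cross_entropy, so if input_length and target_length differ,
# the flattened sizes differ as well, which seems to be the 72 vs 50 above.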

Can someone help me with what is wrong here, and why the input batch size does not match the target batch size even after padding?