...stack expects each tensor to be equal size... problem with, supposedly, integer labels

Hello, I am stuck preparing a dataset class for a multi-class model that should use a multi-class cross-entropy loss.

Note that `data_row.text` is a string and `data_row.y` is an integer. In my test data, the `y` column looks like this:

```
y
0     7
1    59
2    34
dtype: int64
```

Code:

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

## Dataset

#del FakeNewsDataset

class FakeNewsDataset(Dataset):

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: AutoTokenizer,
        max_token_len: int = 2048
    ):
        self.tokenizer = tokenizer
        self.data = data
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        text = data_row.text
        labels = data_row.y  # [LABEL_COLUMNS]
        print(labels)
        print(type(labels))

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return dict(
            #text=text,
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            # source of the error: an int argument is read as a tensor SIZE
            labels=torch.LongTensor(labels)
        )

## run sample through the model

fnd = FakeNewsDataset(
    train_df,
    tokenizer
)

sample_batch = next(iter(DataLoader(fnd, batch_size=20, num_workers=1)))
```

```
RuntimeError: stack expects each tensor to be equal size, but got [1] at entry 0 and [2] at entry 2
```

This error only happens when batch_size >= 3; with smaller batches it works fine.
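The batch-size dependence can be reproduced in isolation. This is a minimal sketch, assuming labels of 1, 1 and 2 for the first three rows (inferred from the shapes in the error message, not taken from the actual data). It skips the tokenizer and the DataLoader and calls `torch.stack` directly, which is what the default collate function uses to batch tensors.

```python
import torch

# Legacy constructor: a single int argument is interpreted as a SIZE,
# so these are uninitialized tensors of length 1, 1 and 2, not labels.
buggy = [torch.LongTensor(label) for label in (1, 1, 2)]
print([tuple(t.shape) for t in buggy])  # [(1,), (1,), (2,)]

# The first two have the same shape, so a "batch" of two stacks fine...
pair = torch.stack(buggy[:2])
print(pair.shape)  # torch.Size([2, 1])

# ...but adding the third element reproduces the error from the post.
try:
    torch.stack(buggy)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Note that even the batch of two "works" only in the sense that it does not crash: the stacked values are whatever happened to be in memory, not the labels.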

So, it always helps to formulate a problem in text.

Fixed it this way:

```python
labels=torch.LongTensor([labels]).squeeze()
```
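To check what the fixed label looks like after batching, here is a sketch with a hypothetical `ToyLabelDataset` that returns only labels (no tokenizer). It compares the fix above with `torch.tensor(label, dtype=torch.long)`, which builds the same 0-dim int64 tensor more directly, and feeds the collated batch to `nn.CrossEntropyLoss`. The label values and `num_classes = 60` are assumptions chosen to cover the codes seen in the post.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class ToyLabelDataset(Dataset):
    """Hypothetical stand-in for FakeNewsDataset that returns only labels."""

    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        return dict(
            # The post's fix: the list makes the int DATA; squeeze gives a 0-dim tensor.
            labels_fixed=torch.LongTensor([label]).squeeze(),
            # Equivalent and more direct: a 0-dim int64 tensor holding the value.
            labels=torch.tensor(label, dtype=torch.long),
        )


num_classes = 60  # assumption: label codes lie in [0, 60)
loader = DataLoader(ToyLabelDataset([7, 59, 34, 1, 2]), batch_size=5)
batch = next(iter(loader))

# Default collation stacks the 0-dim tensors into shape [batch_size].
print(batch["labels"].shape)  # torch.Size([5])
print(torch.equal(batch["labels"], batch["labels_fixed"]))  # True

# CrossEntropyLoss takes logits of shape [batch, num_classes] and
# class-index targets of shape [batch] with dtype long.
logits = torch.randn(5, num_classes)
loss = nn.CrossEntropyLoss()(logits, batch["labels"])
print(loss.item())
```

Both variants produce identical batches; `torch.tensor(...)` just avoids the legacy constructor and the extra `squeeze()`.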

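I was passing a plain integer to `torch.LongTensor`, which treats an integer argument as a size rather than a value. So labels 0, 1 and 2 produced uninitialized tensors with zero, one or two (random) elements instead of a one-element tensor holding the label, and tensors of different lengths cannot be stacked into a batch. Smaller batches only worked because the first two rows happened to produce tensors of the same shape; their contents were still garbage. A newbie mistake on my part. Wrapping the label in a list makes it the tensor's data, and `.squeeze()` then reduces the shape-[1] tensor to a 0-dim scalar, so the DataLoader collates the labels into a tensor of shape [batch_size], which is the target shape the cross-entropy loss expects for class indices.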