Problem loading batch to GPU in PyTorch using a BERT model

I want to implement a BERT model for a classification task, based on this tutorial: Tutorial: Fine-tuning BERT for Sentiment Analysis - by Skim AI.

Training step

def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """
    Train the BertClassifier model.
    """
    # Start training loop
    print("Start training...\n")
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # Put the model into the training mode
        model.train()

        # For each batch of training data...
        for step, batch in enumerate(train_dataloader):
            batch_counts +=1
            # Load batch to GPU
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask)

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()

            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch

                # Print training results
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)

        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation == True:
            # After the completion of each training epoch, measure the model's performance
            # on our validation set.
            val_loss, val_accuracy = evaluate(model, val_dataloader)

            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
        print("\n")
    
    print("Training complete!")

After training the model I want to refit it on both the training and validation sets combined. I should mention that texts_train and texts_valid are lists.

# Train our model on the entire training data
# Concatenate the train set and the validation set
full_train_data = torch.utils.data.ConcatDataset([texts_train, texts_valid])
full_train_sampler = RandomSampler(full_train_data)
full_train_dataloader = DataLoader(full_train_data, sampler=full_train_sampler, batch_size=32)

# Train the Bert Classifier on the entire training data
set_seed(42)
bert_classifier, optimizer, scheduler = initialize_model(epochs=2)
train(model=bert_classifier, train_dataloader=full_train_dataloader, val_dataloader=None, epochs=2, evaluation=False)

And I get the following error:

AttributeError: 'str' object has no attribute 'to'

The full traceback is attached as a screenshot.

Based on the error message it seems that batch contains a str object (besides tensors or other objects), which can’t be pushed to the GPU. Could you check the content of batch and call to(device) only on tensors?

Thank you so much for your help. Could you be more specific about the coding part? I am a beginner in PyTorch, so I'm not very familiar with the concepts yet.

Check what types are returned by the DataLoader via:

for batch in full_train_dataloader:
    for b in batch:
        print(type(b))

and make sure they are all tensors.
Based on the error message:

tuple(t.to(device) for t in batch)

crashes since t is a str object while a tensor is expected.
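
As a rough sketch of the "only tensors to the GPU" idea (the helper name batch_to_device is my own, not part of the tutorial), you could guard the .to(device) call with an isinstance check:

import torch

def batch_to_device(batch, device):
    # Move only tensor items to the device; leave anything else (e.g. raw strings) untouched.
    return tuple(t.to(device) if isinstance(t, torch.Tensor) else t for t in batch)

# Inside the training loop this would replace the unconditional call:
# b_input_ids, b_attn_mask, b_labels = batch_to_device(batch, device)

Note that this only avoids the crash; if the DataLoader really yields strings, the model still cannot train on them, so the underlying dataset needs fixing.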

The output is the following.

Input:

for batch in full_train_dataloader:
    for b in batch:
        print(type(b))

Output:

<class 'torch.Tensor'>

Input:

tuple(t.to(device) for t in batch)

Output:
(tensor([[  101,  3183,  6656,  1288,   121,  2981,  3070,   349,  3824,  7192,
            613,   273, 11811,  4858,   435,   358,   708,  3234, 24109,  2504,
          11753,   121,   435,   349,   802,  7192,   344,   345,  3865, 18340,
          13825, 12820,  9857,   358, 24860,   273,  3712,   273,  1463,  6286,
           1873,   774,   588,  9857, 10423,  1571,   373,   121,   509,  2947,
           6461,  2489,  1287,  2550,   530,   346,   433,  6302,   247,   281,
            280,  1250,   614,  1742,   414, 22267,   247, 11883,   515,  1288,
            239,   440,  1508,  1383,   355,  8489,  1723,  4385,   558,  8518,
            278,  1354,   683,   273, 12948,   239,   273, 18340,   414,  2359,
          12948,   770,   267,   564,  9857,  1050,  3234,   749,  4624,   278,
            344, 10240,  2746,  1622,  2201,   278,   121,   374,   344,   345,
            349, 22321,   278,  1393, 29187,  1475,  3102,   273,  3186, 11811,
           4858, 14588,   247,  6911,  8089,   483,  4651,   278,   247,  6911,
            237,   261,   281,  1423,  6373,   278,  1294,  1604, 16231,   121,
            344,   102,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [  101,  4570,  6509,   273,   435,   349,   518,   939, 13508,   569,
           4999,  1678,   267,   381,  2981,  1372,  7511,   278,  1060,   247,
            281,   349,  6171,  4905,  6029,   769, 13056,  2447,   497,   433,
           1136,  3815, 24109,  2504, 10765,   483,   278,   121,  2344,   236,
            437,  2359,  1792,   566,  6424,  5409,   661,   273,  1354,   344,
          11797,   267,   668,   273,  2416,  1413,   281,  2928,   410,   273,
            269,   350, 13715,   273,  1678,  4385,   421,   346,  5101,  4201,
           4961,   281, 17700,   390, 13958,   273,  2330,  6040,   273,   121,
            247,   351,  2108,   269, 18687,   278,   691,  1436,   281,  1413,
           4872,  3744, 11811,   468,   344,   346,  1423, 10760,  3001,   381,
            367,   643,   832,  1946, 10986,   278,  9038, 11072,  1413,   281,
           1857,   904,   643,  8092,  6654,  9857,   121,   100,   121,   518,
          21391,   389,   121, 21391,  1032,  6874,  5532,   344,  3468,  5532,
            382,   350,  1183,   273,   352, 13120,   273,   121,  9038,   433,
            344, 23006,   378,   387,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [  101,   358,  4668, 10009,   781, 10290,   273,  1200, 16760,   121,
            363,   433, 10373,  2658,   269,   929,   351,   273,  9656,   344,
          12547,   269,  2378,  9483,   236,   257, 15158,  6167,   404,  3156,
          14553,   387, 17238,  8850,   742,   278,  3885, 14229,  7789,   344,
            452,  4824,  6357,  9857, 29526,   373,   352,  6636,   278,   247,
           7151,  1270,   121,   247,   281,   441,  5646,  3679,   605,   744,
            437,   868,   868,  2669,  5595,   278, 16372,   265,  6686,   814,
            344,   433,  2927,   605,   518,   358, 22972,   373, 17621,   278,
           1294,  1403,  3058,   420,   404,   606,   344,  5463, 34931,  2313,
            362, 17426,  4430,   273,   352,  4791,   421,   355,   278,  1643,
            624,   742,   278,   344,  1995,   546,   421,   345,   349,  3664,
            613,   273,  1678,   421, 16940,   483,   344, 11989,   285, 26157,
            273,  3815,  2411, 16514,  4113,   273,   344, 28701,   121,   346,
            433, 17426,  4430,   281,   355,   278,   351,   273,  1643,   624,
            452,   466,   269,  5779, 18123,   281,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [  101, 19828,  3998, 14696,  9968,   904,   381,   367,  3343,  1571,
           4858,   247,  8009,   278, 22267,  1802, 19828,   269,   285,   414,
          28672,  1191,  5758,  2179,  9038,  3815, 17657,   483,   278, 16252,
            934,   121,   655,   247,   281,   273, 29478,  1032,  4569,  1721,
            345, 19828,  3998,   239,   271,  1742,   494,   382,   362,   713,
            513,   273,  1821,   729,  1112,   410,  1199,   255,  3573,   373,
            643,   376, 31409,   273,   121,   247,   281,  2947,   440, 27967,
           3815,   924,   273,  4569,  3810,  1825,   273, 16252,   934,   239,
            100,   595,  7314,  1571,   267, 10580,   546,   464, 11367,   433,
           1698,  9857,   358,  4067,   350,  4020,   273,   358,  1995,  3344,
           7385,   255,  3573,   373,   121,  2273,  2947,   355,  1821,  1340,
            990,   358,   248,  2351,  1894,   464,   350,   358,   255,  3573,
            373, 12263,   273,   437,  7312,   827,   121,   355,   351,   273,
           2947,  1821,  1340,   990, 14397, 16907,   273,   357,  2706,   483,
            362,  5075,   273,   355,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]], device='cuda:0'),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'),
 tensor([9, 9, 3, 9], device='cuda:0'))

This single batch seems to contain the expected tensors, so you would either have to check the entire dataset or verify where the str is coming from in your code.
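
Here is a rough sketch of both checks (assumptions: full_train_data is the ConcatDataset from your snippet, and train_inputs, train_masks, train_labels, val_inputs, val_masks, val_labels are hypothetical names for the tensors produced in the tutorial's tokenization step):

import torch
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler, TensorDataset

# 1) Scan the entire concatenated dataset for non-tensor entries.
for idx in range(len(full_train_data)):
    sample = full_train_data[idx]
    items = sample if isinstance(sample, (tuple, list)) else (sample,)
    for item in items:
        if not isinstance(item, torch.Tensor):
            print(idx, type(item))

# 2) Likely cause: if texts_train and texts_valid are lists of raw strings,
#    ConcatDataset([texts_train, texts_valid]) makes the DataLoader yield strings.
#    Concatenate the tokenized TensorDatasets instead (hypothetical tensor names).
train_data = TensorDataset(train_inputs, train_masks, train_labels)
val_data = TensorDataset(val_inputs, val_masks, val_labels)

full_train_data = ConcatDataset([train_data, val_data])
full_train_sampler = RandomSampler(full_train_data)
full_train_dataloader = DataLoader(full_train_data, sampler=full_train_sampler, batch_size=32)

If the scan in step 1 prints any indices, those samples are where the str objects come from.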
