RuntimeError: CUDA error: device-side assert triggered when fine-tuning BERT

C:/w/b/windows/pytorch/aten/src/THC/THCTensorIndex.cu:272: block: [178,0,0], thread: [31,0,0] Assertion srcIndex < srcSelectDimSize failed.
Traceback (most recent call last):
  File "C:/Users/lhuang93/PycharmProjects/pythonProject/train.py", line 190, in <module>
    F = train(config, 'ddi_e-5.log')
  File "C:/Users/lhuang93/PycharmProjects/pythonProject/train.py", line 130, in train
    train_loss, train_pred = run_iter(batch=train_batch, is_training=True)
  File "C:/Users/lhuang93/PycharmProjects/pythonProject/train.py", line 102, in run_iter
    logits = model(input_ids, attention_mask, token_type_ids, label, e1_mask, e2_mask)
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\lhuang93\PycharmProjects\pythonProject\model.py", line 60, in forward
    outputs = self.bert(input_ids, attention_mask=attention_mask,
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda\Anaconda\lib\site-packages\transformers\modeling_bert.py", line 829, in forward
    embedding_output = self.embeddings(
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda\Anaconda\lib\site-packages\transformers\modeling_bert.py", line 211, in forward
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\modules\sparse.py", line 124, in forward
    return F.embedding(
  File "D:\Anaconda\Anaconda\lib\site-packages\torch\nn\functional.py", line 1814, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: CUDA error: device-side assert triggered

I had the same error in one of my machine translation projects.
I solved it by removing the very long sentences from my datasets; a rough sketch of what I mean is below.
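Something along these lines (just a sketch; max_len, sentence_pairs and the tokenizer are placeholders for whatever your project uses, assuming the tokenizer has a tokenize() method):

max_len = 300  # whatever limit fits your model
kept_pairs = [
    (src, tgt)
    for src, tgt in sentence_pairs  # hypothetical list of (source, target) sentence pairs
    if len(tokenizer.tokenize(src)) <= max_len and len(tokenizer.tokenize(tgt)) <= max_len
]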

If you share your code, maybe I will be able to help you better.

But the maximum length of my sentences is no more than 300, so I set the max length to 300.
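To make sure the subword token count (not just the raw sentence length) stays under 300, a quick check like this should work (just a sketch, reusing the tokenizer and read_data from the code below):

lengths = [len(tokenizer.tokenize(sent)) for sent, _, _, _ in read_data('.', 'train.json')]
print(max(lengths))  # should stay below config.MAX_LENGTH; anything longer is truncated by pad()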

The following is my data-processing code for relation extraction. I added the special tokens <e1>, </e1>, <e2>, </e2> to mark the entities.

import json
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
import random
from torch.utils.data import DataLoader, TensorDataset
import os
import logging
import torch
import numpy as np
from tqdm import tqdm
import config

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)-8s %(message)s')

model_name = 'bert-base-uncased'



tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_name, do_lower_case=config.do_lower_case)
tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})



label_path='label2id.json'
label2id = json.load(open(label_path, 'r'))
NA_id = label2id['NA']

def tokenizer_entity(sent, headword_pos, tailword_pos):
    sent = list(sent)
    if headword_pos[1]<tailword_pos[0]:
        sent.insert(headword_pos[0], "<e1>")
        sent.insert(headword_pos[1]+1, "</e1>")
        sent.insert(tailword_pos[0]+2, "<e2>")
        sent.insert(tailword_pos[1]+3, "</e2>")
    else:
        sent.insert(tailword_pos[0], "<e2>")
        sent.insert(tailword_pos[1] + 1, "</e2>")
        sent.insert(headword_pos[0] + 2, "<e1>")
        sent.insert(headword_pos[1] + 3, "</e1>")
    sent = "".join(sent)
    return sent

def read_data(file_dir, filename):
    data = []
    data_path = os.path.join(file_dir, filename)
    d = json.load(open(data_path, 'r'))
    for ins in d:
        sent = ins['sentence'].replace('\n', '').lower()
        label = label2id.get(ins['relation'], NA_id)
        tail_word = ins['tail']['word']
        head_word = ins['head']['word']
        data.append([sent, label,head_word,tail_word])
    random.shuffle(data)
    return data


def process_data(data, max_length):
    def pad(x):
        return x[:max_length] if len(x) > (max_length) else x + [0] * ((max_length) - len(x))
    # sent_raw = [x for x, _, _, _ in data]
    # labels = [y for _, y, _, _ in data]
    # head_word = [h for _, _, h, _ in data]
    # tail_word = [t for _, _, _, t in data]

    input_ids_pad=[]
    input_mask_data=[]
    input_segment_data=[]
    input_labels=[y for _, y, _, _ in data]
    e1_mask_data = []
    e2_mask_data = []
    for ins in tqdm(data):
        sent_ins = ins[0]

        head_word = ins[2]
        tail_word = ins[3]

        head_pos = sent_ins.index(head_word)
        head_pos = [head_pos, head_pos + len(head_word)]
        tail_pos = sent_ins.index(tail_word)
        tail_pos = [tail_pos, tail_pos + len(tail_word)]

        sent_ins = tokenizer_entity(sent_ins,head_pos,tail_pos)
        tokenized_text = tokenizer.tokenize(sent_ins)


        tokenized_text = ["CLS"]+tokenized_text



        e11_p = tokenized_text.index("<e1>")
        e12_p = tokenized_text.index("</e1>")
        e21_p = tokenized_text.index("<e2>")
        e22_p = tokenized_text.index("</e2>")

        tokenized_text[e11_p] = "$"
        tokenized_text[e12_p] = "$"
        tokenized_text[e21_p] = "#"
        tokenized_text[e11_p] = "#"




        input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)

        input_ids = pad(input_ids)

        input_mask = [1 if i != 0 else 0 for i in input_ids]
        input_segment = [0 for i in input_ids]

        # e1 mask, e2 mask
        e1_mask = [0] * len(input_mask)
        e2_mask = [0] * len(input_mask)

        # set the token positions of entity 1 and entity 2 to 1 in e1_mask and e2_mask
        for i in range(e11_p,e12_p+1):
            if i>len(e1_mask)-1:
                print(sent_ins)
                print(tokenized_text)
                print(i)
                print(len(e1_mask))
                exit()
            else:
                e1_mask[i] = 1
        for i in range(e21_p,e22_p+1):
            if i>len(e2_mask)-1:
                print(sent_ins)
                print(tokenized_text)
                print(i)
                print(len(e2_mask))
                exit()
            else:
                e2_mask[i] = 1

        input_ids_pad.append(input_ids)
        input_mask_data.append(input_mask)
        input_segment_data.append(input_segment)
        e1_mask_data.append(e1_mask)
        e2_mask_data.append(e2_mask)

    input_ids_pad = torch.tensor(input_ids_pad,dtype=torch.long)
    input_mask_data = torch.tensor(input_mask_data,dtype=torch.long)
    input_segment_data = torch.tensor(input_segment_data,dtype=torch.long)
    e1_mask_data = torch.tensor(e1_mask_data,dtype=torch.long)
    e2_mask_data = torch.tensor(e2_mask_data,dtype=torch.long)
    input_labels = torch.tensor(input_labels,dtype=torch.long)

    return input_ids_pad, input_mask_data, input_segment_data, input_labels, e1_mask_data, e2_mask_data

def get_dataset(file_dir, filename, max_length):
    data = read_data(file_dir,filename)
    input_ids_pad, input_mask, input_segment, input_labels, e1_mask_data, e2_mask_data = process_data(data, max_length)
    dataset = TensorDataset(input_ids_pad, input_mask,input_segment, input_labels, e1_mask_data, e2_mask_data)
    return dataset


def dump_dataset(data_name):
    dataset = get_dataset(file_dir='.', filename=data_name + '.json', max_length=config.MAX_LENGTH)
    torch.save(dataset, data_name + '.pt')


dump_dataset('train')
dump_dataset('valid')
dump_dataset('test')

Model code:

from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction,BertPreTrainedModel
from transformers import BertModel
import torch.nn as nn
import config
import torch


class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True):
        super(FCLayer, self).__init__()
        self.use_activation = use_activation
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.dropout(x)
        if self.use_activation:
            x = self.tanh(x)

        return self.linear(x)

class Mybert(BertPreTrainedModel):
    def __init__(self, bert_config, args):
        super(Mybert, self).__init__(bert_config)
        self.bert = BertModel.from_pretrained(args.pretrained_model_name, config=bert_config)  # Load pretrained bert

        self.num_labels = bert_config.num_labels

        self.cls_fc_layer = FCLayer(bert_config.hidden_size, bert_config.hidden_size, args.dropout_rate)
        self.e1_fc_layer = FCLayer(bert_config.hidden_size, bert_config.hidden_size, args.dropout_rate)
        self.e2_fc_layer = FCLayer(bert_config.hidden_size, bert_config.hidden_size, args.dropout_rate)
        self.label_classifier = FCLayer(bert_config.hidden_size * 3, bert_config.num_labels, args.dropout_rate, use_activation=False)

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """
        Average the entity hidden state vectors (H_i ~ H_j)
        :param hidden_output: [batch_size, j-i+1, dim]
        :param e_mask: [batch_size, max_seq_len]
                e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
        :return: [batch_size, dim]
        """
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, j-i+1]
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1)  # [batch_size, 1]

        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)  # [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcasting
        return avg_vector

    def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        # Average
        e1_h = self.entity_average(sequence_output, e1_mask)
        e2_h = self.entity_average(sequence_output, e2_mask)

        # Dropout -> tanh -> fc_layer
        pooled_output = self.cls_fc_layer(pooled_output)
        e1_h = self.e1_fc_layer(e1_h)
        e2_h = self.e2_fc_layer(e2_h)

        # Concat -> fc_layer
        concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
        logits = self.label_classifier(concat_h)

        return logits

Besides the sequence length, you should also check the inputs (min and max values) to the embedding layer, which often contain out-of-bounds indices.
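For example, something like this (just a sketch using the variable names from your code above; run it on the CPU or with CUDA_LAUNCH_BLOCKING=1 so the failing index is reported at the right line):

# compare the word and token-type ids against the embedding sizes
vocab_size = model.bert.embeddings.word_embeddings.num_embeddings
print(input_ids.min().item(), input_ids.max().item(), vocab_size)
assert input_ids.max().item() < vocab_size
assert token_type_ids.max().item() < model.bert.config.type_vocab_size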

Thank you for your reply. I solved this problem by adding the following code:

model.resize_token_embeddings(len(tokenizer))
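In context, the call goes right after adding the special tokens to the tokenizer and before moving the model to the GPU, roughly like this (bert_config, args, and device are just placeholders for however they are built in train.py):

tokenizer.add_special_tokens({"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]})
model = Mybert(bert_config, args)
model.resize_token_embeddings(len(tokenizer))  # grow the word embeddings to cover the new token ids
model.to(device)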

I hope it helps other people who run into the same problem. :grinning:
