I am using BERT to classify text segments extracted from ASR-generated transcripts. There are multiple segments per participant. I have stored the data in a dataframe with the following column names: Participant_ID, Segment_Text and Diagnosis (i.e. label). I have successfully trained my dataset and am trying to identify which segments of text (i.e. Segment_Text) and Participant_ID the model misclassified - how would I go about doing this? I have provided the code for the custom Dataset class and Evaluation class (i.e. testing the BERT model) below:
Custom Dataset Class
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {'HC':0, 'PK':1}
class Dataset(torch.utils.data.Dataset):
def __init__(self, df):
self.participantIDs = df['Participant_ID']
self.labels = [labels[label] for label in df['Diagnosis']]
self.texts = [tokenizer(text, padding='max_length', max_length = 512, truncation=True, return_tensors="pt") for text in df['Transcript_Segment']]
def classes(self):
return self.labels
def __len__(self):
return len(self.labels)
def get_batch_labels(self, idx):
# Fetch a batch of labels
return np.array(self.labels[idx])
def get_batch_texts(self, idx):
# Fetch a batch of inputs
return self.texts[idx]
# def get_batch_participant_ids(self, idx):
# # Fetch a batch of inputs
# return self.participantIDs[idx]
def __getitem__(self, idx):
batch_texts = self.get_batch_texts(idx)
batch_y = self.get_batch_labels(idx)
# batch_participant_ids = self.get_batch_participant_ids(idx)
return batch_texts, batch_y
Evaluate BERT model class
def evaluate(model, test_data):
incorrect_samples = []
test = Dataset(test_data)
test_dataloader = torch.utils.data.DataLoader(test, batch_size=8)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
model = model.cuda()
total_acc_test = 0
with torch.no_grad():
for test_input, test_label in test_dataloader:
test_label = test_label.to(device)
mask = test_input['attention_mask'].to(device)
input_id = test_input['input_ids'].squeeze(1).to(device)
output = model(input_id, mask)
_, pred = torch.max(output,1)
idxs_mask = ((pred == test_label) == False).nonzero()
print(idxs_mask)
incorrect_samples.append(input_id[idxs_mask].cpu().detach().numpy())
acc = (output.argmax(dim=1) == test_label).sum().item()
total_acc_test += acc
print(f'Test Accuracy: {total_acc_test / len(test_data): .3f}')
print(incorrect_samples)