HuggingFace dataset error: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Hello.
I have taken code from several sources that work with the Common Voice dataset. The only modification I made was to change the language from Turkish to Persian.
When I run the code, I encounter this error on the line trainer.train():

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

This is my code. You can copy and paste it into Google Colab (where I ran it) and run it:

!pip install datasets==1.13.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.0.19
!pip install torchaudio
!pip install librosa
!pip install jiwer
!apt install git-lfs
!pip install hazm
!pip install pydub
!pip install pythainlp
!pip install num2fawords  # assumed source of the `words` helper used in normalizer() below
import os
import re
#from typing import List, Dict, Tuple
import pandas as pd
from scipy.io import wavfile
#from spell_correction import correct_sentence
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from pydub import AudioSegment
from pythainlp.tokenize import word_tokenize, syllable_tokenize
from datasets import load_dataset, load_from_disk, load_metric
import hazm
import string
import torch
# NOTE (assumption): `words` converts an integer to its Persian word form;
# in the sources this code is based on, it appears to come from num2fawords.
from num2fawords import words
#os.environ['CUDA_VISIBLE_DEVICES']='2, 3'

torch.cuda.empty_cache()

#print(torch.cuda.memory_summary(device=None, abbreviated=False))

print(torch.cuda.is_available())




_normalizer = hazm.Normalizer()

chars_to_ignore = [
    ",", "?", ".", "!", "-", ";", ":", '""', "%", "'", '"', "�",
    "#", "!", "؟", "?", "«", "»", "،", "(", ")", "؛", "'ٔ", "٬",'ٔ', ",", "?", 
    ".", "!", "-", ";", ":",'"',"“", "%", "‘", "”", "�", "–", "…", "_", "”", '“', '„',
    'ā', 'š',
#     "ء",
]

# In case of farsi
chars_to_ignore = chars_to_ignore + list(string.ascii_lowercase + string.digits)
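# NOTE: normalizer() applies chars_to_mapping before stripping chars_to_ignore,
# so ASCII letters are transliterated to Persian first and only leftover Latin
# characters and digits are removed.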

chars_to_mapping = {
    'ك': 'ک', 'دِ': 'د', 'بِ': 'ب', 'زِ': 'ز', 'ذِ': 'ذ', 'شِ': 'ش', 'سِ': 'س', 'ى': 'ی',
    'ي': 'ی', 'أ': 'ا', 'ؤ': 'و', "ے": "ی", "ۀ": "ه", "ﭘ": "پ", "ﮐ": "ک", "ﯽ": "ی",
    "ﺎ": "ا", "ﺑ": "ب", "ﺘ": "ت", "ﺧ": "خ", "ﺩ": "د", "ﺱ": "س", "ﻀ": "ض", "ﻌ": "ع",
    "ﻟ": "ل", "ﻡ": "م", "ﻢ": "م", "ﻪ": "ه", "ﻮ": "و", 'ﺍ': "ا", 'ة': "ه",
    'ﯾ': "ی", 'ﯿ': "ی", 'ﺒ': "ب", 'ﺖ': "ت", 'ﺪ': "د", 'ﺮ': "ر", 'ﺴ': "س", 'ﺷ': "ش",
    'ﺸ': "ش", 'ﻋ': "ع", 'ﻤ': "م", 'ﻥ': "ن", 'ﻧ': "ن", 'ﻭ': "و", 'ﺭ': "ر", "ﮔ": "گ",
        
    # "ها": "  ها", "ئ": "ی",
    "۱۴ام": "۱۴ ام",
        
    "a": " ای ", "b": " بی ", "c": " سی ", "d": " دی ", "e": " ایی ", "f": " اف ",
    "g": " جی ", "h": " اچ ", "i": " آی ", "j": " جی ", "k": " کی ", "l": " ال ",
    "m": " ام ", "n": " ان ", "o": " او ", "p": " پی ", "q": " کیو ", "r": " آر ",
    "s": " اس ", "t": " تی ", "u": " یو ", "v": " وی ", "w": " دبلیو ", "x": " اکس ",
    "y": " وای ", "z": " زد ",
    "\u200c": " ", "\u200d": " ", "\u200e": " ", "\u200f": " ", "\ufeff": " ",
}


def multiple_replace(text, chars_to_mapping):
    pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
    return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))

def remove_special_characters(text, chars_to_ignore_regex):
    text = re.sub(chars_to_ignore_regex, '', text).lower() + " "
    return text

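# NOTE: this text-level normalizer is shadowed by the batch-level normalizer
# redefined further below, which is the version actually passed to .map().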
def normalizer(text, chars_to_ignore=chars_to_ignore, chars_to_mapping=chars_to_mapping):
    chars_to_ignore_regex = f"""[{"".join(chars_to_ignore)}]"""
    text = text.lower().strip()

    text = _normalizer.normalize(text)
    text = multiple_replace(text, chars_to_mapping)
    text = remove_special_characters(text, chars_to_ignore_regex)
    text = re.sub(" +", " ", text)
    _text = []
    for word in text.split():
        try:
            # convert purely numeric tokens to their Persian word form
            _text.append(words(int(word)))
        except ValueError:
            _text.append(word)
            
    text = " ".join(_text) + " "
    text = text.strip()

    if not len(text) > 0:
        return None
    
    return text + " "

data_dir = "cv-corpus-9.0-2022-04-27/fa"

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("common_voice", "fa", split="train")
common_voice_train = common_voice_train.select(range(500))
common_voice_dev = load_dataset("common_voice", "fa", split="validation")
common_voice_dev = common_voice_dev.select(range(50))
common_voice_test = load_dataset("common_voice", "fa", split="test")
common_voice_test = common_voice_test.select(range(50))

print(common_voice_train)
print(common_voice_dev)
print(common_voice_test)

from datasets import ClassLabel
import random
import pandas as pd
#from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    print(df)  # print all sampled rows (df.head() would cap the output at 5)

#show_random_elements(common_voice_train.remove_columns(["path"]), num_examples=20)

def normalizer(batch, chars_to_ignore=chars_to_ignore, chars_to_mapping=chars_to_mapping):
    chars_to_ignore_regex = f"""[{"".join(chars_to_ignore)}]"""
    text = batch["sentence"].lower().strip()

    text = _normalizer.normalize(text)
    text = multiple_replace(text, chars_to_mapping)
    text = remove_special_characters(text, chars_to_ignore_regex)
    text = re.sub(" +", " ", text)
    _text = []
    for word in text.split():
        try:
            # convert purely numeric tokens to their Persian word form
            _text.append(words(int(word)))
        except ValueError:
            _text.append(word)
            
    text = " ".join(_text) + " "
    text = text.strip()

    if not len(text) > 0:
        return None

    if len(text) >= 32:
        text = text[:30]
    
    batch["sentence"] = text
    
    return batch

#print(common_voice_train[0]["sentence"])
#print(common_voice_dev[0]["sentence"])
#print(common_voice_test[0]["sentence"])

common_voice_train = common_voice_train.map(normalizer, fn_kwargs={"chars_to_ignore": chars_to_ignore, "chars_to_mapping": chars_to_mapping})
common_voice_dev = common_voice_dev.map(normalizer, fn_kwargs={"chars_to_ignore": chars_to_ignore, "chars_to_mapping": chars_to_mapping})
common_voice_test = common_voice_test.map(normalizer, fn_kwargs={"chars_to_ignore": chars_to_ignore, "chars_to_mapping": chars_to_mapping})

#print(common_voice_train[0]["sentence"])
#print(common_voice_dev[0]["sentence"])
#print(common_voice_test[0]["sentence"])

def extract_all_chars(batch):
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=4, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_dev = common_voice_dev.map(extract_all_chars, batched=True, batch_size=4, keep_in_memory=True, remove_columns=common_voice_dev.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=4, keep_in_memory=True, remove_columns=common_voice_test.column_names)

vocab_list = list(sorted(set(vocab_train["vocab"][0]) | set(vocab_dev["vocab"][0]) | set(vocab_test["vocab"][0])))
vocab_list = [vocab for vocab in vocab_list if vocab not in [" ", "\u0307"]]
print(len(vocab_list))
print(vocab_list)


special_vocab = ["<pad>", "<s>", "</s>", "<unk>", "|"]
vocab_dict = {v: k for k, v in enumerate(special_vocab + vocab_list)}
print(len(vocab_dict))
print(vocab_dict)

# find the tokens that were assigned ids 5 and 8 and drop them from the vocabulary
for token, token_id in vocab_dict.items():
    if token_id == 5:
        k1 = token
    elif token_id == 8:
        k2 = token

del vocab_dict[k1]
del vocab_dict[k2]

import json
with open('vocab.json', 'w') as vocab_file:
	json.dump(vocab_dict, vocab_file)

from transformers.trainer_utils import get_last_checkpoint

save_dir = "model checkpoints/"

last_checkpoint = None
if os.path.exists(save_dir):
    last_checkpoint = get_last_checkpoint(save_dir)

print(last_checkpoint if last_checkpoint else str(None))

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json", 
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        word_delimiter_token="|",
        do_lower_case=False,
        max_length=31
    )

text = "از مهمونداری کنار بکشم"
print(" ".join(tokenizer.tokenize(text)))
print(tokenizer.decode(tokenizer.encode(text)))

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

if len(processor.tokenizer.get_vocab()) == len(processor.tokenizer):
	print(len(processor.tokenizer))

if not os.path.exists(save_dir):
	print("Saving ...")
	processor.save_pretrained(save_dir)
	print("Saved!")

import torchaudio
import librosa


target_sampling_rate = 16_000

def speech_file_to_array_fn(batch):
	speech_array, sampling_rate = torchaudio.load(batch["path"])
	speech_array = speech_array.squeeze().numpy()
	speech_array = librosa.resample(np.asarray(speech_array), orig_sr=sampling_rate, target_sr=target_sampling_rate)

	batch["speech"] = speech_array
	batch["sampling_rate"] = target_sampling_rate
	batch["duration_in_seconds"] = len(batch["speech"]) / target_sampling_rate
	batch["target_text"] = batch["sentence"]
	return batch

common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names)
common_voice_dev = common_voice_dev.map(speech_file_to_array_fn, remove_columns=common_voice_dev.column_names)
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)

#print(common_voice_train[0]["sampling_rate"])
#print(common_voice_test[0]["sampling_rate"])


min_duration_in_seconds = 5.0
max_duration_in_seconds = 10.0

def filter_by_duration(batch):
	return min_duration_in_seconds <= batch["duration_in_seconds"] <= max_duration_in_seconds

print(f"Split sizes [BEFORE]: {len(common_voice_train)} train and {len(common_voice_test)} validation.")



_common_voice_train = common_voice_train.filter(filter_by_duration)
_common_voice_dev = common_voice_dev
_common_voice_test = common_voice_test
# _common_voice_test = common_voice_test.filter(filter_by_duration, num_proc=4)

print(f"Split sizes [AFTER]: {len(_common_voice_train)} train and {len(_common_voice_test)} validation.")

# check that all files have the correct sampling rate
def prepare_dataset(batch):
	assert len(set(batch["sampling_rate"])) == 1, (
		f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
	)

	batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

	with processor.as_target_processor():
		batch["labels"] = processor(batch["target_text"]).input_ids

	return batch

_common_voice_train = _common_voice_train.map(prepare_dataset, remove_columns=_common_voice_train.column_names, batch_size=4, batched=True)
_common_voice_dev = _common_voice_dev.map(prepare_dataset, remove_columns=_common_voice_dev.column_names, batch_size=4, batched=True)
_common_voice_test = _common_voice_test.map(prepare_dataset, remove_columns=_common_voice_test.column_names, batch_size=4, batched=True)

#_common_voice_train.set_format(type='torch', columns=['input_values', 'labels'])
#_common_voice_dev.set_format(type='torch', columns=['input_values', 'labels'])
#_common_voice_test.set_format(type='torch', columns=['input_values', 'labels'])


###############################################################################################################

#torch.cuda.empty_cache()

#print(torch.cuda.memory_summary(device=None, abbreviated=False))

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
	processor: Wav2Vec2Processor
	padding: Union[bool, str] = True
	max_length: Optional[int] = None
	max_length_labels: Optional[int] = None
	pad_to_multiple_of: Optional[int] = None
	pad_to_multiple_of_labels: Optional[int] = None

	def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
		input_features = [{"input_values": feature["input_values"]} for feature in features]
		label_features = [{"input_ids": feature["labels"]} for feature in features]

		batch = self.processor.pad(
			input_features,
			padding=self.padding,
			max_length=self.max_length,
			pad_to_multiple_of=self.pad_to_multiple_of,
			return_tensors="pt",
			)
		with self.processor.as_target_processor():
			labels_batch = self.processor.pad(
				label_features,
				padding=self.padding,
				max_length=self.max_length_labels,
                #max_length=64,
				pad_to_multiple_of=self.pad_to_multiple_of_labels,
				return_tensors="pt",
				)

		labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

		batch["labels"] = labels

		return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

wer_metric = load_metric("wer")

import random


def compute_metrics(pred):
	pred_logits = pred.predictions
	pred_ids = np.argmax(pred_logits, axis=-1)

	pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

	pred_str = processor.batch_decode(pred_ids)

	label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

	if isinstance(label_str, list):
		if isinstance(pred_str, list) and len(pred_str) == len(label_str):
			for index in random.sample(range(len(label_str)), 3):
				print(f'reference: "{label_str[index]}"')
				print(f'predicted: "{pred_str[index]}"')
		else:
			for index in random.sample(range(len(label_str)), 3):
				print(f'reference: "{label_str[index]}"')
				print(f'predicted: "{pred_str}"')

	wer = wer_metric.compute(predictions=pred_str, references=label_str)

	return {"wer": wer}


from transformers import Wav2Vec2ForCTC, Wav2Vec2Config

configuration = Wav2Vec2Config(hidden_size=256, num_hidden_layers=6, num_attention_heads=6, intermediate_size=1024)

model_args = {}

print(f"Tokenizer vocab size: {len(processor.tokenizer.get_vocab())}")

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53" if not last_checkpoint else last_checkpoint, 
    #model_name_or_path if not last_checkpoint else last_checkpoint,
    attention_dropout=0.1,
    #hidden_size=256,
    #num_hidden_layers=8,
    #num_attention_heads=2,
    #intermediate_size=256,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    ctc_zero_infinity=True,
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer.get_vocab())
    #vocab_size=64
)

model.config = configuration
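# NOTE: this overwrites the loaded checkpoint's config (hidden_size=1024 for
# wav2vec2-large-xlsr-53) with the small config defined above, while the
# weights themselves stay unchanged, so config and weights no longer match.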


model.freeze_feature_extractor()

model = model.to(torch.device("cuda"))

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=save_dir,
    group_by_length=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=0.5,
    fp16=True,
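    # train in half precision on the GPU (cf. the torch.cuda.HalfTensor
    # reported for the weights in the error message)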
    #save_steps=10,
    #eval_steps=10,
    #logging_steps=10,
    learning_rate=1e-4,
    #warmup_steps=500,
    #save_total_limit=2,
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=_common_voice_train,
    eval_dataset=_common_voice_test,
    tokenizer=processor.feature_extractor,
)

torch.cuda.empty_cache()

train_result = trainer.train()


metrics = train_result.metrics
metrics["train_samples"] = len(_common_voice_train)

trainer.save_model()

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

I’d be really thankful to anyone who can solve my problem. Please help me, as this is driving me mad.
P.S.: There is a commented-out block that sets the format of the _common_voice datasets to torch tensors. However, even when I run it, I still encounter the same error.

Based on the error message, it seems you’ve pushed the model to the GPU, but not the inputs.
I don’t know what the Trainer class is doing internally, but I would guess that Trainer.train should also be responsible for pushing the data to the GPU.
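For reference, this is the general PyTorch pattern: the inputs have to live on the same device (and be of a compatible dtype) as the model's weights before the forward pass. A minimal, self-contained sketch (toy model, not your script):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)  # weights now live on the GPU (if available)

x = torch.randn(4, 10)               # input tensor created on the CPU
# out = model(x)                     # on a GPU machine this raises a device/type mismatch
out = model(x.to(device))            # moving the input to the model's device fixes it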

My guess is also that the data is on the CPU while the model is on the GPU. However, I really don’t know how to push a HuggingFace Arrow dataset to the GPU. I even tried pushing the batch to CUDA inside that DataCollatorCTCWithPadding class, but then I receive another error.
I am not so fluent in PyTorch, and I have seen you help people a lot. This is very important for me, but I really don’t know how to fix it. Please help me fix it. Thank you very much.
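For reference, the change I tried inside DataCollatorCTCWithPadding was roughly this (a sketch of my attempt, not working code):

# at the end of DataCollatorCTCWithPadding.__call__:
batch["labels"] = labels
batch = {k: v.to("cuda") for k, v in batch.items()}  # push the padded batch to CUDA
return batch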

I’m unfortunately not familiar enough with the internals of the classes used, but I think you might want to cross-post the question on the HuggingFace discussion board, as the developers would certainly know what causes the error.

I managed to solve the issue. All I had to do was not move the data and the model to CUDA myself. There seems to be a problem on the HuggingFace side when this is done manually.
Thank you anyway.
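Concretely, the only change was deleting the explicit device move and letting Trainer handle placement on its own:

# before (raised the RuntimeError during trainer.train()):
# model = model.to(torch.device("cuda"))

# after: no manual .to() call; Trainer moves the model and each batch
# to the right device (and handles the fp16 casting) itself.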