Process Batch Transformation on GPU (expected device cuda:0 but got device cpu)

My dataset returns audio files of fixed length, and the generator batches them into batches of 64 files. I want to preprocess the whole batch on the GPU to speed things up, because processing the audio file by file is time-consuming. When I run the following:

for x, y in training_generator:
    x = preprocess(x.to(device))

I get the error ‘expected device cuda:0 but got device cpu’

The preprocess function is below:

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F


def STFT(batch, N_FFT, TOP_DB, chroma=False):
    window = torch.hann_window(N_FFT)
    STFT = torch.stft(batch, N_FFT, window=window)
    # the STFT output has 2 channels (real and imaginary); we need the magnitude
    a_2 = torch.mul(STFT[:, :, :, 0], STFT[:, :, :, 0])
    b_2 = torch.mul(STFT[:, :, :, 1], STFT[:, :, :, 1])
    a_2_plus_b_2 = torch.add(a_2, b_2)

    STFT_Mag = torch.sqrt(a_2_plus_b_2)
    if chroma: return STFT_Mag
    STFT_DB = torchaudio.transforms.AmplitudeToDB('magnitude', top_db=TOP_DB)(STFT_Mag)
    return STFT_DB

def MELS(batch, MELSPECTOGRAM_PARAMETERS, TOP_DB):
    MEL_SPEC = torchaudio.transforms.MelSpectrogram(**MELSPECTOGRAM_PARAMETERS)(batch)
    return torchaudio.transforms.AmplitudeToDB('power', top_db=TOP_DB)(MEL_SPEC)

def MFCC(batch, N_MFCC, N_MELS, MEL_SPEC_DB):
    DCT = torchaudio.functional.create_dct(N_MFCC, N_MELS, 'ortho')
    return torch.matmul(MEL_SPEC_DB.transpose(1, 2), DCT).transpose(1, 2)

def CHROMA(batch, CFB_INIT, N_FFT, TOP_DB):
    STFT_DB = STFT(batch, N_FFT, TOP_DB, True).to(device)
    CFB = torch.zeros(STFT_DB.shape[0], CFB_INIT.shape[0], CFB_INIT.shape[1])
    CFB[:, :, :] = CFB_INIT
    
    raw_chroma = torch.bmm(CFB, STFT_DB)

    # Normalize each frame by its infinity norm
    return F.normalize(raw_chroma, p=float("inf"), dim=1, eps=1e-12)

def TONNETZ(CHROMA):
    # Generate Transformation matrix
    dim_map = torch.linspace(0, 12, steps=CHROMA.shape[1])

    scale = torch.FloatTensor([7. / 6, 7. / 6,
                               3. / 2, 3. / 2,
                               2. / 3, 2. / 3])

    V = torch.ger(scale, dim_map)

    # Even rows compute sin()
    V[::2] -= 0.5

    R = torch.FloatTensor([1, 1,
                           1, 1,
                           0.5, 0.5]).to(device) 

    phi = R[:, np.newaxis] * torch.cos(torch.pi * V)
    phi_s = torch.zeros(CHROMA.shape[0], phi.shape[0], phi.shape[1])
    phi_s[:, :, :] = phi

    # Do the transform to tonnetz
    return torch.bmm(phi_s, CHROMA)

def preprocess(batch, 
               frames=SAMPLING_RATE*WINDOW_SIZE,
               MELSPECTOGRAM_PARAMETERS_P=MELSPECTOGRAM_PARAMETERS,
               TOP_DB=TOP_DB,
               N_MFCC=N_MFCC,
               N_MELS=N_MELS,
               CFB_INIT=CFB_INIT):
    Final_Data = torch.zeros(batch.shape[0], 78, 313, 2).to(device)
    Final_Data[:, 0:60, :, 0]  = MELS(batch, MELSPECTOGRAM_PARAMETERS, TOP_DB)
    Final_Data[:, 0:60, :, 1]  = MFCC(batch, N_MFCC, N_MELS, Final_Data[:, 0:60, :, 0])
    Final_Data[:, 60:72, :, 0] = CHROMA(batch, CFB_INIT, N_FFT, TOP_DB)
    Final_Data[:, 60:72, :, 1] = Final_Data[:, 60:72, :, 0]
    Final_Data[:, 72:, :, 0]   = TONNETZ(Final_Data[:, 60:72, :, 0])
    Final_Data[:, 72:, :, 1]   = Final_Data[:, 72:, :, 0]
    return Final_Data

Do I have the wrong approach here? When I was doing the processing inside the dataset, each batch took forever to produce, so I wanted to batch the transformation as above and run it on the GPU, but I'm unable to.

Which line of code is raising the error?
I guess that some internal torchaudio.transforms might create CPU tensors and thus yield the device mismatch error.
We could take a look at it if you post the complete stack trace.
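
In case it helps narrow things down, here is a minimal sketch (a guess until the stack trace confirms the failing line) of two places in the posted code where CPU tensors are created by default, and how they could be kept on the batch's device; the names simply mirror the functions above:

def STFT(batch, N_FFT, TOP_DB, chroma=False):
    # torch.hann_window returns a CPU tensor unless a device is passed explicitly
    window = torch.hann_window(N_FFT, device=batch.device)
    spec = torch.stft(batch, N_FFT, window=window)
    # magnitude from the real/imaginary channels
    mag = torch.sqrt(spec[..., 0] ** 2 + spec[..., 1] ** 2)
    if chroma:
        return mag
    return torchaudio.transforms.AmplitudeToDB('magnitude', top_db=TOP_DB)(mag)

def MELS(batch, MELSPECTOGRAM_PARAMETERS, TOP_DB):
    # MelSpectrogram is an nn.Module holding CPU buffers (mel filter bank, window)
    # until it is moved to the GPU with .to()
    mel_spec = torchaudio.transforms.MelSpectrogram(**MELSPECTOGRAM_PARAMETERS).to(batch.device)
    to_db = torchaudio.transforms.AmplitudeToDB('power', top_db=TOP_DB)
    return to_db(mel_spec(batch))

The tensors built with torch.zeros, torch.linspace, and torchaudio.functional.create_dct inside CHROMA, TONNETZ, and MFCC would need the same treatment (a device= argument or a .to(batch.device) call), but the stack trace will show which call actually fails first.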