I am loading a model from Hugging Face and exporting it to a `.pt` file.
However, the output of the exported model (a TorchScript module) differs from the output of the original PyTorch model.
What could be causing the problem?
Export Script:
import torch
from speechbrain.pretrained.interfaces import Pretrained
class Encoder(Pretrained):
    """Speaker-embedding extractor built on speechbrain's ``Pretrained`` interface.

    Chains the ``compute_features``, ``mean_var_norm`` and ``embedding_model``
    modules declared in the model's hyperparameter file and, on request,
    applies the ``mean_var_norm_emb`` normalisation taken from ``hparams``.
    """

    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Make the instance callable; delegates straight to :meth:`encode_batch`."""
        return self.encode_batch(wavs, wav_lens, normalize)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for one waveform or a batch of waveforms.

        Args:
            wavs: waveform tensor; a 1-D tensor is treated as a single
                waveform, otherwise dim 0 is the batch axis (presumably
                ``(batch, samples)`` -- confirm against the feature module).
            wav_lens: per-item relative lengths; ``None`` means every item
                uses its full length (a tensor of ones).
            normalize: when true, the embeddings are additionally passed
                through ``hparams.mean_var_norm_emb`` with full lengths.

        Returns:
            The ``embedding_model`` output, optionally normalised.

        Note:
            The checks on ``wavs.ndim``, ``wav_lens is None`` and
            ``normalize`` are plain Python control flow, so ``torch.jit.trace``
            records only the branch taken by the example input.
        """
        if wavs.ndim == 1:
            wavs = wavs.unsqueeze(0)  # lone waveform -> batch of one
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Put both inputs on the model's device; features expect float input.
        wav_lens = wav_lens.to(self.device)
        wavs = wavs.to(self.device).float()

        mods = self.mods
        normed_feats = mods.mean_var_norm(mods.compute_features(wavs), wav_lens)
        embeddings = mods.embedding_model(normed_feats, wav_lens)

        if not normalize:
            return embeddings
        # The embedding normaliser is always given full relative lengths.
        full_lens = torch.ones(embeddings.shape[0], device=self.device)
        return self.hparams.mean_var_norm_emb(embeddings, full_lens)
def main():
    """Load the pretrained encoder, trace it to TorchScript and save ``ECAPA-TDNN-VOX2.pt``.

    ``torch.jit.trace_module`` records only the operations executed for the
    single example input: Python control flow (such as the branches in
    ``Encoder.encode_batch``) and any tensor-to-Python-number conversions are
    frozen into the graph as constants. A trace recorded on one 1-second clip
    can therefore be wrong for audio of other lengths, so the trace is compared
    with the eager model on inputs of different lengths before it is saved.

    Raises:
        RuntimeError: if the traced module's output diverges from the eager
            model's output for any probe length.
    """
    classifier = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
    classifier.eval()
    sample_wavs = torch.randn(1, 16000)  # assuming 1 second of audio with 16kHz sample rate
    input_dict = {
        "forward": (sample_wavs,),
    }
    with torch.no_grad():
        scripted_model = torch.jit.trace_module(classifier, inputs=input_dict)
        scripted_model.eval()
        # Probe with lengths other than the one the trace was recorded with:
        # anything shape-dependent that got frozen shows up as a mismatch here.
        for num_samples in (8000, 48000):
            probe = torch.randn(1, num_samples)
            expected = classifier(probe)
            actual = scripted_model(probe)
            if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-4):
                max_diff = (actual - expected).abs().max().item()
                raise RuntimeError(
                    f"Traced model diverges from the eager model for an input of "
                    f"{num_samples} samples (max abs diff {max_diff:.4g}); the trace "
                    "is only valid for the example input's shape. Consider "
                    "torch.jit.script or tracing only shape-independent submodules."
                )
    scripted_model.save("ECAPA-TDNN-VOX2.pt")


# Guarded so that `from export_yang_wang_tdnn import Encoder` (as done by the test
# script) does not re-download, re-trace and overwrite the saved model on import.
if __name__ == "__main__":
    main()
Test Script:
import torch
import torchaudio
from export_yang_wang_tdnn import Encoder
def calculate_similarity_score(embs1, embs2, threshold=0.7):
    """Compare two speaker embeddings, print a same/different decision and return the score.

    The score is the cosine similarity of the two embeddings rescaled from
    [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Args:
        embs1, embs2: embedding tensors; singleton dimensions (e.g. an
            unsqueezed batch axis) are squeezed away before comparison.
        threshold: scores at or above this value are reported as the same
            speaker (default 0.7, the previously hard-coded value).

    Returns:
        The rescaled similarity score as a Python float.
    """
    x = embs1.squeeze()
    y = embs2.squeeze()
    # cosine_similarity divides by both norms itself (so no separate
    # pre-normalisation is needed) and clamps them with a small eps, so an
    # all-zero embedding gives a cosine of 0 instead of NaN.
    cosine = torch.nn.functional.cosine_similarity(x, y, dim=0)
    # .item() so the score prints and compares as a plain float, not a tensor.
    similarity_score = ((cosine + 1) / 2).item()
    # Decision
    if similarity_score >= threshold:
        print(" two audio files are from same speaker")
    else:
        print(" two audio files are from different speakers")
    print(f"Similarity Score 2: {similarity_score}")
    return similarity_score
def get_embedding(audio_file, model, target_sample_rate=16000):
    """Load an audio file and return the model's embedding for it.

    Args:
        audio_file: path accepted by ``torchaudio.load``.
        model: callable mapping a waveform tensor to an embedding (the eager
            ``Encoder`` or the loaded TorchScript module).
        target_sample_rate: sample rate fed to the model; audio at any other
            rate is resampled first. Defaults to 16000, the rate the export
            script's sample input is described with (presumably the training
            rate -- confirm against the model card). Pass ``None`` to keep the
            file's native rate.

    Returns:
        Whatever ``model`` returns for a single-channel ``(1, num_samples)``
        waveform.
    """
    audio, sample_rate = torchaudio.load(audio_file)
    # torchaudio.load returns (channels, samples) and the encoder treats dim 0
    # as the batch, so a stereo file would yield one embedding per channel.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    # The file's native rate is otherwise ignored; feeding e.g. 44.1 kHz audio
    # to a model expecting 16 kHz silently changes the embeddings.
    if target_sample_rate is not None and sample_rate != target_sample_rate:
        audio = torchaudio.functional.resample(audio, sample_rate, target_sample_rate)
    return model(audio)
if __name__ == "__main__":
    with torch.no_grad():
        # Embeddings from the TorchScript (traced) export.
        model_path = "ECAPA-TDNN-VOX2.pt"
        exported_model = torch.jit.load(model_path, map_location=torch.device("cpu"))
        exported_model.eval()
        traced_embs1 = get_embedding("1.wav", exported_model).squeeze()
        traced_embs2 = get_embedding("2.wav", exported_model).squeeze()
        print("\n\n\n")
        print(traced_embs1)
        print("\n\n\n")
        calculate_similarity_score(traced_embs1, traced_embs2)

        # Embeddings from the original eager PyTorch model, used as the reference.
        classifier = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
        classifier.eval()
        embs1 = get_embedding("1.wav", classifier).squeeze()
        embs2 = get_embedding("2.wav", classifier).squeeze()
        print("\n\n\n")
        print(embs1)
        print("\n\n\n")
        calculate_similarity_score(embs1, embs2)

        # Eyeballing two printed vectors is unreliable: report the numeric gap
        # between the two models on the same file. Anything well above float
        # rounding means the export does not reproduce the eager model here.
        for audio_file, traced, eager in (
            ("1.wav", traced_embs1, embs1),
            ("2.wav", traced_embs2, embs2),
        ):
            max_abs_diff = (traced - eager).abs().max().item()
            print(f"{audio_file}: max |traced - eager| = {max_abs_diff:.6g}")
Output (first from the exported TorchScript model, then from the original PyTorch model):
tensor([ 17.4265, 18.4827, -6.5524, -15.7520, 8.8310, 22.0934, 8.2339,
-32.3875, -6.8510, -12.3607, -13.9849, -0.7548, -9.6365, -1.0562,
8.6763, -3.4279, 11.0967, -1.8245, 7.0722, -9.3896, 10.9931,
-17.9823, 14.5675, -28.1516, 20.2218, 21.6082, -17.8120, 11.6475,
4.9818, 17.7021, 15.5902, -5.9433, 6.0047, 31.9569, 22.4159,
-5.7942, 0.1527, 3.0175, 13.3096, -14.3408, 8.4869, -6.1609,
4.4746, 10.4516, 11.5534, -5.6418, -3.7477, -8.5428, -10.6561,
10.8571, -0.8358, -17.7481, 17.6726, -7.8003, 13.5207, -4.7161,
6.7001, 6.1731, -9.6309, -24.5352, 12.2250, 5.2186, 28.9895,
-22.8382, 2.5595, 2.2920, 10.0660, -12.2751, -4.0394, -7.0524,
-19.2292, -18.1063, 10.5839, -3.6522, -9.6226, -7.5372, -6.7760,
1.9211, 23.8775, 3.0158, 14.5255, 4.2744, -8.1205, -4.2562,
-15.9318, 10.3941, -23.2881, 7.7236, -6.5062, 0.2158, -0.8689,
-19.4896, -9.6370, 21.9226, -2.7052, 27.3228, 3.1160, -4.0933,
8.4077, -8.1299, -3.8143, -6.4555, -1.1031, 4.5192, -37.7678,
-6.4635, -16.3251, -3.7136, 10.3487, 5.7073, -13.1537, 3.4378,
-9.6158, 20.4664, 8.5428, 11.4182, 1.5992, 10.0996, -14.3282,
7.2150, 12.3470, -3.6431, -23.9484, 3.5467, 1.8392, -27.1732,
13.8923, 4.3795, -1.0498, -8.0016, 2.3717, 2.4841, -13.7343,
-5.5505, -7.9340, -18.2014, -0.6432, -1.1195, -8.1046, -19.6473,
3.3592, 2.0802, 5.1738, -2.9291, -1.1100, -15.7681, 20.3292,
4.2611, -4.6319, 12.7996, -2.6277, 14.3068, -11.4737, -9.5859,
-17.4647, 19.6631, -0.1224, 21.5704, -5.0022, 22.3308, 19.6326,
-6.7660, -2.2318, 3.4402, 19.4086, -18.6844, 5.7403, -15.3297,
4.5604, 5.5215, 3.1186, -26.0653, -6.8775, 23.6961, 7.7369,
-7.3478, -6.0051, -5.4176, -10.1827, -27.6532, -4.3664, 3.7130,
6.5594, -8.6329, -5.5614, 13.5024, 19.7162, 10.0846, -7.0208,
-0.8902, -9.8562, 12.0251])
two audio files are from same speaker
Similarity Score 2: 0.9412376284599304
tensor([ 2.7420e+00, 1.1649e+01, -8.8248e+00, 1.9869e+01, -1.3088e+01,
-1.0886e+01, -1.6435e+01, 4.6483e+00, -3.9572e+00, 4.4734e+00,
1.2895e+01, 4.4200e+00, 3.4495e+00, -1.5029e+00, 1.2837e+01,
-4.7832e+00, -4.3518e+00, -1.6307e+01, 1.1015e+01, 1.8744e+01,
1.0738e+01, -2.2187e+00, 2.6528e+01, 1.1487e+01, 7.4944e+00,
9.3273e+00, -1.2424e+01, 1.6159e+01, -5.0016e+00, -9.5605e+00,
-1.5786e+00, -7.9519e+00, -3.1426e-01, -9.8059e+00, 1.2994e+01,
-2.9743e+00, -1.8329e+01, 8.8164e+00, 2.0401e+01, 2.3679e-02,
7.7053e+00, -4.3322e+00, 1.5231e+01, -2.3924e+00, 5.4399e+00,
-3.3659e+00, -8.3692e+00, -1.1856e+00, -3.4969e+00, 9.8103e+00,
-1.6941e+00, 1.1031e+00, 9.5047e+00, -1.4897e+01, 2.3147e+00,
-1.0449e+01, 8.7767e-01, -1.0616e+01, 1.7602e+00, 6.5198e+00,
1.7019e+01, 9.6794e+00, 3.1800e-01, 5.7724e-01, -1.5201e+01,
1.7264e+00, -2.5351e+00, -1.3069e+01, 9.8878e+00, -2.9789e+01,
7.5117e+00, -5.4878e+00, 4.3513e+00, 2.3655e+00, -9.4151e+00,
-1.0562e+01, 7.4361e+00, 3.8250e+00, 1.3992e+01, 5.8453e-01,
-5.2812e+00, 1.4257e+01, 1.3429e+01, 6.0729e+00, 5.1320e+00,
1.5210e+01, -1.4795e+01, 8.5817e+00, 6.6284e+00, 1.3744e+01,
-1.3318e+01, -1.6463e+01, -7.6232e-01, 1.6622e+01, 6.3580e+00,
1.2637e+01, 1.4080e+01, -6.9219e+00, -5.2070e+00, -1.9272e+00,
8.5520e+00, 7.1814e+00, -5.7860e+00, -1.4527e+00, -3.3659e+00,
4.3329e+00, -4.3502e+00, -6.5604e+00, 8.8280e-01, 2.4577e+00,
-5.9011e-01, 9.0167e+00, -2.6019e+00, 1.1001e-01, 2.0047e-01,
-4.5963e-01, 7.5912e+00, 7.4606e+00, -1.3943e+01, -4.9876e+00,
8.9396e+00, 4.1880e+00, -1.9634e+01, -2.2300e+01, 2.4642e+00,
-1.4048e+00, 1.7877e+01, 3.2127e+00, 1.3258e+01, -2.0172e+00,
2.6299e+00, -7.4409e+00, -7.0494e+00, -5.6323e+00, -5.1883e+00,
1.1370e+01, -1.4824e+01, -4.4420e+00, 5.6955e-01, 1.6458e+01,
1.1723e+01, 9.0847e+00, 3.3529e+00, -6.2683e+00, 1.0708e+01,
9.5542e-01, 2.6537e-01, 1.4606e+01, -1.1009e+00, -5.6804e-01,
-4.9638e+00, 3.1467e+00, -1.6994e+01, 5.1069e+00, 1.9925e+01,
-1.0569e+01, -9.9650e-01, 1.2112e+01, -2.6073e+00, 6.1138e+00,
6.1991e+00, -2.0184e+00, 1.0387e+01, -7.6137e+00, -6.7222e+00,
4.6268e-01, 3.1711e+00, -3.3232e+00, -6.8194e+00, -1.4877e+01,
1.0600e+01, -1.0826e+01, -1.2478e+00, 2.1171e+01, -2.4119e+00,
1.4314e+00, 1.2130e+00, 6.6838e+00, -5.6456e+00, -1.3453e+01,
-1.1070e+01, -7.9374e+00, -7.0325e+00, 2.7956e+00, 4.8391e+00,
8.0030e+00, 4.2917e+00, 1.3327e+00, 8.1383e-01, -7.7327e+00,
-2.2210e-01, 6.8660e+00])
two audio files are from different speakers
Similarity Score 2: 0.5776064395904541