I have a loss function called SI-SNR loss function implemented as follows:
def si_snr_loss(ests, egs):
refs = egs["ref"]
num_spks = len(refs)
def sisnr_loss(permute):
return sum([sisnr(ests[s], refs[t]) for s, t in enumerate(permute)]) / len(permute)
N = egs["mix"].size(0)
sisnr_mat = torch.stack([sisnr_loss(p) for p in permutations(range(num_spks))])
max_perutt,_ = torch.max(sisnr_mat, dim=0)
return -torch.sum(max_perutt) / N
I want to replace this loss function with an MSE loss function I write a simple equation as follows:
for egs in val_dataloader:
current_step += 1
egs = to_device(egs, self.device)
ests = data_parallel(self.net, egs['mix'], device_ids=self.gpuid)
#loss = si_snr_loss(ests, egs)
loss = (ests - torch.Tensor(np.array(egs.values())))**2
losses.append(loss.item())
Unfortunately, this gives me an error:
TypeError: can’t convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, int64, int32, int16, int8, uint8, and bool.