I have created a model in the following way. It mimics the straight-through gradient estimation proposed in VQ-VAE, with the difference that, for some of the training epochs, I reset the codewords parameter after every epoch.
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from typing import List, Optional


class KMeansQuantizer(nn.Module):
    def __init__(self, num_codewords: int, dim: int, seglen: int, filespath: str, codewords_trainable: Optional[bool] = False):
        super().__init__()
        self.num_codewords = num_codewords
        self.dim = dim
        self.codewords_trainable = codewords_trainable
        # the segment database is only needed when the codewords are
        # periodically re-estimated with k-means instead of being trained
        if not self.codewords_trainable:
            self.db = self.get_segs(filespath, seglen)
        self.codewords = nn.Parameter(torch.ones(num_codewords, dim))  # nn.Parameter(0.01*torch.randn(num_codewords, dim))

    @staticmethod
    def get_segs(filespath: str, seglen: int) -> List[torch.Tensor]:
        with open(filespath, "r") as f:
            fnames = f.readlines()
        db = []
        for fname in tqdm(fnames):
            # read_audio (defined elsewhere) loads the waveform as a 1-D tensor
            audio_data = read_audio(fname.strip())
            # cut the waveform into non-overlapping segments of length seglen
            chunks = F.unfold(audio_data[None, None, None, :], kernel_size=(1, seglen), stride=(1, seglen)).squeeze(0).T
            db.append(chunks.unsqueeze(1))
        return db

    def run_kmeans(self, model: List[nn.Module], niter: Optional[int] = 100, verbose: Optional[bool] = True, seed: Optional[int] = 100, device: Optional[str] = None) -> None:
        model[0].eval()
        model[1].eval()
        Z_db = []
        with torch.no_grad():
            print("preparing database for clustering")
            for batch in tqdm(self.db):
                ze = model[1](model[0](batch.to(device)))
                Z_db.append(ze.detach().cpu().numpy())
        Z_db = np.concatenate(Z_db).reshape(-1, self.dim)
        self.kmeans = faiss.Kmeans(self.dim, self.num_codewords, niter=niter, verbose=verbose, seed=seed, spherical=True)
        self.kmeans.train(Z_db)
        # reset the codeword values to the new centroids -- this is the line in question
        self.codewords.data = torch.from_numpy(self.kmeans.centroids).to(device).detach().clone()
        model[0].train()
        model[1].train()

    def _pass_grad(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Straight-through gradient estimation:
        for y = f(x), where f is non-differentiable, we set dL/dx = dL/dy.
        """
        return y.detach() + (x - x.detach())

    @staticmethod
    def _get_cosine_disimilarity(x, y):
        # 1 - cosine similarity, assuming x and y are unit-normalized
        cosine_dist = 1 - torch.einsum('bd,bd->b', x.view(-1, x.shape[-1]), y.view(-1, y.shape[-1]))
        return cosine_dist.mean()

    def forward(self, ze):
        results = {}
        if self.codewords_trainable:
            codewords = self.codewords / torch.linalg.norm(self.codewords, dim=-1, keepdim=True)
        else:
            # spherical k-means already returns unit-norm centroids
            codewords = self.codewords
        results["codewords"] = codewords
        # nearest codeword by cosine similarity (ze is assumed unit-normalized)
        nearest_clust_idx = torch.argmax(torch.einsum("btd,kd->btk", ze, codewords), dim=-1)
        zq = codewords[nearest_clust_idx]
        results["zq"] = self._pass_grad(ze, zq)  # straight-through quantized output
        results["idx"] = nearest_clust_idx
        if self.codewords_trainable:
            results["latent_loss"] = self._get_cosine_disimilarity(zq, ze.detach())
            results["commitment_loss"] = self._get_cosine_disimilarity(ze, zq.detach())
        return results
I want to reset the parameter's values before a training epoch starts. Within an epoch the parameter still gets updated at every training step; I just want to reset it again before the next epoch begins, and to repeat this for some of the training epochs only, roughly as in the loop sketched below.
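Here, encoder, projection, decoder, loader, optimizer, criterion and reset_epochs are all hypothetical placeholders:

for epoch in range(num_epochs):
    if epoch in reset_epochs:  # only for some of the training epochs
        quantizer.run_kmeans([encoder, projection], device=device)  # reset codewords to fresh k-means centroids
    for batch in loader:
        optimizer.zero_grad()
        ze = projection(encoder(batch.to(device)))
        out = quantizer(ze)  # straight-through quantization
        loss = criterion(decoder(out["zq"]), batch.to(device))  # hypothetical downstream loss
        loss.backward()
        optimizer.step()  # the parameter keeps updating within the epoch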
The way I do it right now is to reset its data attribute inside the run_kmeans method, as shown below, but afterwards the parameter's gradient is None:
self.codewords.data = torch.from_numpy(self.kmeans.centroids).to(device).detach().clone()
Could someone please let me know what the correct implementation is? Would an in-place copy under torch.no_grad(), as in the sketch below, be the right approach instead?
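A toy sketch of the in-place variant I mean (the shapes and new_centroids are hypothetical stand-ins for self.codewords and self.kmeans.centroids):

import numpy as np
import torch
import torch.nn as nn

codewords = nn.Parameter(torch.ones(8, 4))              # toy stand-in for self.codewords
new_centroids = np.random.rand(8, 4).astype("float32")  # toy stand-in for self.kmeans.centroids
with torch.no_grad():
    codewords.copy_(torch.from_numpy(new_centroids))    # in-place reset; the Parameter object is unchanged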