Reset a model parameter value regularly during training

I have created a model in the following way. The model mimics the straight-through gradient estimation as proposed in the VQ-VAE, with the difference that I reset the codewords parameter after every epoch for some training epochs.

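from typing import List, Optional

import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# read_audio is a user-defined helper (not shown) that loads a file as a 1-D tensor

class KMeansQuantizer(nn.Module):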
    def __init__(self, num_codewords: int, dim: int, seglen: int, filespath: str, codewords_trainable: Optional[bool]=False):
        super().__init__()
        self.num_codewords = num_codewords
        self.dim = dim
        self.codewords_trainable = codewords_trainable
        if not self.codewords_trainable:
            self.db = (self.get_segs(filespath, seglen))
        self.codewords = nn.Parameter(torch.ones(num_codewords, dim)) #nn.Parameter(0.01*torch.randn(num_codewords, dim))

    @staticmethod
    def get_segs(filespath: str, seglen: int)-> List[torch.Tensor]:
        # read the list of audio file names, closing the file afterwards
        with open(filespath, "r") as f:
            fnames = f.readlines()
        db = []
        for fname in tqdm(fnames):
            audio_data = read_audio(fname.strip())
            chunks = F.unfold(audio_data[None, None, None, :], kernel_size=(1,seglen), stride=(1,seglen)).squeeze_(0).T
            db.append(chunks.unsqueeze_(1))
        return db

    def run_kmeans(self, model: List[nn.Module], niter:Optional[int]=100, verbose:Optional[bool]=True, seed:Optional[int]=100, device: str="cpu")-> None:
        model[0].eval()
        model[1].eval()
        Z_db = []
        with torch.no_grad():  
            print("preparing database for clustering")
            for batch in tqdm(self.db):
                ze = model[1](model[0](batch.to(device)))
                Z_db.append(ze.detach().cpu().numpy())
            Z_db = np.concatenate(Z_db).reshape(-1, self.dim)
            # pass the seed to the constructor so that it reaches faiss's clustering parameters
            self.kmeans = faiss.Kmeans(self.dim, self.num_codewords, niter=niter, verbose=verbose, spherical=True, seed=seed)
            self.kmeans.train(Z_db)  
        self.codewords.data =  torch.from_numpy(self.kmeans.centroids).to(device).detach().clone() ######
        model[0].train()
        model[1].train()
        Z_db = None
        return 

    def _pass_grad(self, x: torch.Tensor, y:torch.Tensor)-> torch.Tensor:
        """
        performs gradient bypassing (straight-through estimation)
        for y = f(x), where f is non-differentiable, we set dL/dx = dL/dy
        """
        return y.detach() + (x-x.detach())
    
    @staticmethod
    def _get_cosine_disimilarity(x,y):
        cosine_dist = 1 - torch.einsum('bd, bd -> b', [x.view(-1, x.shape[-1]), y.view(-1, y.shape[-1])])
        return cosine_dist.mean()
        
    def forward(self, ze):
        results = {}
        if self.codewords_trainable:
            codewords = self.codewords/torch.linalg.norm(self.codewords, dim=-1).view(-1,1)
        else:
            codewords = self.codewords
        results["codewords"] = codewords
        
        nearest_clust_idx = torch.argmax(torch.einsum("btd,kd -> btk", [ze, codewords]), dim=-1)
        zq = codewords[nearest_clust_idx]
        x = self._pass_grad(ze, zq)
        results["idx"] = nearest_clust_idx

        if self.codewords_trainable:
            results["latent_loss"] = self._get_cosine_disimilarity(zq, ze.detach())
            results["commitment_loss"] = self._get_cosine_disimilarity(ze, zq.detach())
        

I want to reset the parameter's values before a training epoch starts. They get updated with every training step, but I'd like to reset them for the next training epoch, and to repeat this process for some training epochs only.

The way I do it right now is to reset its .data attribute (in the run_kmeans method) as shown below, but its gradient then turns out to be None:

self.codewords.data = torch.from_numpy(self.kmeans.centroids).to(device).detach().clone()

Could someone please let me know what the correct implementation is?

Hi Anup!

The code you posted shows the forward() method not returning anything.
You also don’t show what you are actually using as a loss criterion and how
it is used in a .backward() call. So it’s hard to tell what might be going on.

Note that the result of torch.argmax() is a LongTensor, so you can’t
backpropagate through that. Also, although zq would appear to be
connected to self.codewords by the computation graph, the zq.detach()
detaches the "commitment_loss" piece of results from the computation
graph. Last, it looks like the result of x = self._pass_grad(ze, zq) isn’t
used.
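
The first two observations can be checked with a small sketch. The tensors below are hypothetical stand-ins (random values, made-up shapes) rather than the poster's encoder output or codebook, and the two losses only loosely imitate "commitment_loss" and "latent_loss". The sketch suggests that the argmax() indices carry no gradient, that a loss built from zq.detach() leaves codewords.grad as None, and that a loss built from zq itself does populate it.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the encoder output and the codebook.
ze = torch.randn(2, 3, 4, requires_grad=True)      # (batch, time, dim)
codewords = torch.nn.Parameter(torch.randn(5, 4))  # (num_codewords, dim)

# Nearest-codeword lookup: the indices are integers, so no gradient flows through them.
idx = torch.argmax(torch.einsum("btd,kd->btk", ze, codewords), dim=-1)
zq = codewords[idx]  # indexing itself is differentiable w.r.t. codewords

# 1) A loss that only sees zq.detach() is cut off from the codewords.
commitment = (1 - (ze * zq.detach()).sum(-1)).mean()
commitment.backward()
grad_after_commitment = codewords.grad  # still None at this point

# 2) A loss that sees zq (not detached) reaches the codewords.
latent = (1 - (zq * ze.detach()).sum(-1)).mean()
latent.backward()
grad_after_latent = codewords.grad  # now a tensor shaped like codewords

print(idx.dtype, idx.requires_grad)
print(grad_after_commitment is None, grad_after_latent.shape)
```

If the loss actually passed to .backward() contains only the detached piece, a None gradient on the codewords would be the expected outcome regardless of how they were reset.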

If none of these observations help you sort out your problem, please post a
simplified, fully-self-contained, runnable script that reproduces your issue,
together with its output.
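
For illustration only, a script of that kind might look like the sketch below. It is not the poster's training code: the random "centroids" stand in for the k-means result, the loss is a made-up cosine-style term, and reset_codewords() is a hypothetical helper. It uses one common way of overwriting a parameter's values, an in-place copy_() under torch.no_grad(), and shows that .grad is None until backward() has been called on a loss that depends on the parameter.

```python
import torch
from torch import nn

torch.manual_seed(0)

codewords = nn.Parameter(torch.ones(4, 3))
opt = torch.optim.SGD([codewords], lr=0.1)


def reset_codewords(param: nn.Parameter, new_values: torch.Tensor) -> None:
    """Hypothetical helper: overwrite the values in place, outside autograd."""
    with torch.no_grad():
        param.copy_(new_values)


grad_before = codewords.grad  # None: backward() has not been called yet

centroids = torch.randn(4, 3)  # stand-in for k-means centroids
reset_codewords(codewords, centroids)
reset_ok = torch.equal(codewords.detach(), centroids)
still_leaf = codewords.is_leaf and codewords.requires_grad

# One training step with a loss that depends on the codewords.
ze = torch.randn(8, 3)  # stand-in for encoder outputs
idx = torch.argmax(ze @ codewords.T, dim=-1)
loss = (1 - (codewords[idx] * ze).sum(-1)).mean()
opt.zero_grad()
loss.backward()
grad_after = codewords.grad  # populated by backward()
opt.step()

print(grad_before is None, reset_ok, still_leaf, grad_after.shape)
```

Because the same Parameter object is kept, an optimizer that was built with it continues to update it after each reset.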

Best.

K. Frank