Feedback on manually implemented hierarchical softmax

I manually implemented hierarchical softmax, since I could not find an existing implementation. My model is a plain word2vec model, but instead of negative sampling I want to use hierarchical softmax. With hierarchical softmax there are no output word representations like the ones used in vanilla softmax or negative sampling; instead there is a target word and a binary tree whose leaf nodes correspond to the words in the vocabulary. I have a question about training and would also like feedback on the correctness of my approach. Thank you in advance.
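Concretely, the objective I am trying to maximise (as I understand hierarchical softmax) for a hidden vector h and target word w is

\log p(w \mid h) = \sum_{n \in \mathrm{path}(w)} \log \sigma\left(s_n \, v_n^\top h\right), \qquad s_n = +1 \text{ if the path turns left at node } n, \; s_n = -1 \text{ if it turns right,}

where the sum runs over the internal nodes on the path from the root to the leaf of w, and v_n is the trainable vector stored at node n; the loss I compute later is the negative of this. The following snippet is my implementation of the binary tree for the softmax; I use a Huffman tree.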

import numpy as np
import torch.nn as nn
import torch
import heapq

class Node:
    # A node of the Huffman tree. Leaves carry a vocabulary token; internal
    # nodes have token=None. Every node gets a 300-d trainable vector, although
    # only the vectors of internal nodes are used in the loss below.
    def __init__(self, token, freq):
        self.vec = torch.randn(300, requires_grad=True, dtype=torch.float)
        self.token = token
        self.freq = freq
        self.left = None
        self.right = None
        
    # order nodes by frequency so heapq can maintain a min-heap
    def __lt__(self, other):
        return self.freq < other.freq
    
    def __gt__(self, other):
        return self.freq > other.freq
    
    def __eq__(self, other):
        if other is None or not isinstance(other, Node):
            return False
        return self.freq == other.freq

class HuffmanTree:
    def __init__(self):
        self.heap = []               # min-heap of Nodes, ordered by frequency
        self.codes = {}              # token -> Huffman code ("0"/"1" string)
        self.reverse_mapping = {}    # Huffman code -> token
        self.root = None
        
    def make_heap(self, frequency):
        # push one leaf node per vocabulary token
        for key in frequency:
            node = Node(key, frequency[key])
            heapq.heappush(self.heap, node)
            
    def merge_nodes(self):
        # repeatedly merge the two least frequent nodes until one root remains
        while len(self.heap) > 1:
            node1 = heapq.heappop(self.heap)
            node2 = heapq.heappop(self.heap)
            
            merged = Node(None, node1.freq + node2.freq)
            merged.left = node1
            merged.right = node2
            heapq.heappush(self.heap, merged)
            
    def make_codes_helper(self, root, current_code):
        if root is None:
            return
        if root.token is not None:
            # leaf node: record the code for this token
            self.codes[root.token] = current_code
            self.reverse_mapping[current_code] = root.token
            return
        
        self.make_codes_helper(root.left, current_code + "0")
        self.make_codes_helper(root.right, current_code + "1")
        
    def make_codes(self):
        root = heapq.heappop(self.heap)
        self.root = root
        current_code = ""
        self.make_codes_helper(root, current_code)

The above snippet constructs the binary Huffman tree for the softmax. Then I build the tree from my vocabulary, so that its leaves are the words in my vocabulary, as follows.

import pickle
d = pickle.load(open("token_freqs.p", "rb"))

# d1 is a dictionary with tokens as keys and their frequencies as values
d1 = {k:v for k,v in sorted(d.items(), key=lambda x: x[1], reverse=True)}

# ht corresponds to one instance of a huffman tree for softmax
ht = HuffmanTree()
ht.make_heap(d1)
ht.merge_nodes()
ht.make_codes()
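As a quick sanity check I print a few of the generated codes; a Huffman tree should assign shorter codes to more frequent tokens (the two tokens below are just placeholders for entries of my vocabulary):

print(len(ht.codes))                      # should equal the vocabulary size
print(ht.codes["some_frequent_token"])    # expected: a short code
print(ht.codes["some_rare_token"])        # expected: a longer code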

Now I define the loss calculation for hierarchical softmax as follows:

def cal_loss(h, target):
    # walk down the Huffman code of the target word, accumulating the
    # log-probability of taking the correct branch at each internal node
    path_to_word = ht.codes[target]
    loss = torch.zeros(1, dtype=torch.float)
    root = ht.root
    for i in path_to_word:
        if i == '0':
            # going left contributes log sigma(v_n . h)
            loss = loss + torch.log(torch.sigmoid(torch.dot(root.vec, h)))
            root = root.left
        else:
            # going right contributes log sigma(-v_n . h)
            loss = loss + torch.log(torch.sigmoid(-torch.dot(root.vec, h)))
            root = root.right
    # return the negative log-likelihood
    return -loss
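To make sure the shapes line up, I call it once on a dummy hidden vector (the token is again just a placeholder from my vocabulary):

h = torch.randn(300)               # stands in for the 300-d hidden vector from the model
print(cal_loss(h, "some_token"))   # a tensor of shape (1,) holding the negative log-likelihood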

And I define my word2vec model class as follows:

vocab_size = 30000

class Word2Vec(nn.Module):
    def __init__(self):
        super().__init__()
        # a bias-free Linear layer applied to a one-hot input is equivalent
        # to looking up a 300-d embedding for that word
        self.embedding_layer = nn.Linear(vocab_size, 300, bias=False)
        
    def forward(self, inp):
        hidd = self.embedding_layer(inp)
        return hidd
    
model = Word2Vec()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for word, target in training_set:   # word: one-hot input vector, target: token string
    hidd = model(word)
    loss = cal_loss(hidd, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
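For completeness, this is roughly how I build the one-hot inputs; word2idx is a token-to-index mapping over the same vocabulary (I assume it has exactly vocab_size entries), and the context-window extraction is omitted:

word2idx = {tok: i for i, tok in enumerate(d1)}   # token -> index

def one_hot(token):
    v = torch.zeros(vocab_size)
    v[word2idx[token]] = 1.0
    return v

# training_set is then an iterable of (one_hot(context_token), target_token) pairs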

Here are my two questions:
(1) I have passed model.parameters() to the optimizer, but I suspect this only includes the parameters of the Word2Vec module. How do I make sure that the parameters stored in the nodes of the Huffman tree built for the softmax approximation also get updated? For an nn.Module I can simply call model.parameters(), but the tree class I implemented offers nothing like that. How do I ensure that training updates all the parameters in the Huffman tree as well? (One idea I had is sketched after question (2) below.)

(2) My second question is about the correctness of my approach. Is this a correct way to train a word2vec model with hierarchical softmax?
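Regarding (1), the one idea I had (I am not sure whether this is the intended way) is to collect the vectors of the internal tree nodes myself and pass them to the optimizer together with the model parameters, roughly like this:

def collect_node_vecs(node):
    # gather the trainable vectors of the internal nodes; leaves are skipped
    # because their vectors are never used in cal_loss
    if node is None or node.token is not None:
        return []
    return [node.vec] + collect_node_vecs(node.left) + collect_node_vecs(node.right)

tree_params = collect_node_vecs(ht.root)
optimizer = torch.optim.SGD(list(model.parameters()) + tree_params, lr=1e-4, momentum=0.9)

Does passing raw tensors to the optimizer like this actually work, or should these vectors be wrapped in nn.Parameter and registered on the module?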

Also, since there is no implementation of hierarchical softmax in the PyTorch library, I would like to contribute a tested version of this technique. Can anyone guide me on what I should do next to make this contribution?

I searched the internet for some references, but I only found the following as an example PyTorch implementation. Maybe you can cross-check your code against it: https://github.com/leimao/Two_Layer_Hierarchical_Softmax_PyTorch/blob/master/utils.py#L98