Hi,
I would like to apply the fast-weight approach in the code below to `nn.Embedding`.
However, the model still calls the original `nn.Embedding.forward` instead of the `forward` method I defined.
I would appreciate your help. Could you tell me why the model behaves like this?
class Linear(nn.Linear):
    """``nn.Linear`` that can run with substitute ("fast") parameters.

    Each registered parameter gets an extra ``fast`` attribute, initialised
    to ``None``. When ``weight.fast`` is set to a tensor (and, if the layer
    has a bias, ``bias.fast`` is set too), ``forward`` computes
    ``F.linear`` with those tensors instead of the registered parameters.
    Otherwise it falls back to ``nn.Linear.forward``.

    NOTE(review): the ``fast`` tensors are presumably assigned by an outer
    training loop; confirm against the caller.

    Args:
        in_features: size of each input sample (as in ``nn.Linear``).
        out_features: size of each output sample (as in ``nn.Linear``).
        bias: whether the layer has an additive bias (default ``True``).
        **kwargs: extra ``nn.Linear`` options (e.g. ``device``, ``dtype``).
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        # Lazy hack to add fast weight link: stash an optional replacement
        # tensor directly on each Parameter object.
        self.weight.fast = None
        if self.bias is not None:  # bias=False leaves self.bias as None
            self.bias.fast = None

    def forward(self, x):
        """Apply the layer, preferring the fast parameters when they are set."""
        bias_fast = None if self.bias is None else self.bias.fast
        # A bias-free layer only needs weight.fast; otherwise both must be set.
        bias_ready = self.bias is None or bias_fast is not None
        if self.weight.fast is not None and bias_ready:
            return F.linear(x, self.weight.fast, bias_fast)
        return super().forward(x)
class Embedding(nn.Embedding):
    """``nn.Embedding`` that can look up rows from a substitute ("fast") weight.

    ``weight`` gets an extra ``fast`` attribute, initialised to ``None``.
    When it holds a tensor, ``forward`` performs the lookup in that tensor
    with the same options as the module (``padding_idx``, ``max_norm``,
    ``norm_type``, ``scale_grad_by_freq``, ``sparse``). While ``weight.fast``
    is ``None``, ``forward`` deliberately falls back to
    ``nn.Embedding.forward``.

    If outputs always match the registered weight, check that
    ``weight.fast`` has been assigned a tensor before the call. Also check
    that the model instantiates this class rather than ``nn.Embedding``.

    Args:
        num_embeddings: size of the lookup table (as in ``nn.Embedding``).
        embedding_dim: size of each embedding vector (as in ``nn.Embedding``).
        **kwargs: extra ``nn.Embedding`` options (e.g. ``padding_idx``).
    """

    def __init__(self, num_embeddings, embedding_dim, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        # Lazy hack to add fast weight link: stash an optional replacement
        # tensor directly on the Parameter object.
        self.weight.fast = None

    def forward(self, user_id):
        """Look up ``user_id`` rows, preferring ``weight.fast`` when it is set."""
        if self.weight.fast is None:
            return super().forward(user_id)
        # Mirror nn.Embedding.forward's options so both paths agree
        # (e.g. no gradient flows into the padding_idx row).
        # NOTE(review): with max_norm set, F.embedding renormalises the given
        # weight in place, which may conflict with autograd on fast weights.
        return F.embedding(
            user_id,
            self.weight.fast,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )