Hey,
I want my model to take the input sequence three times: once as uni-grams, once as bi-grams and once as tri-grams. Each of the three passes through its own LSTM and the outputs are then concatenated.
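Roughly, this is the architecture I'm going for (a simplified, self-contained sketch; the class name, LSTM/linear layer names, hidden size and number of classes are placeholders, not my exact code):

import torch
import torch.nn as nn

class NGramModel(nn.Module):
    # Sketch: one embedding table + one LSTM per n-gram stream,
    # last hidden states concatenated and fed to a linear layer.
    def __init__(self, vocab_n1, vocab_n2, vocab_n3, emb_dim, hidden_dim, num_classes):
        super().__init__()
        self.embed_n1 = nn.Embedding(vocab_n1, emb_dim)
        self.embed_n2 = nn.Embedding(vocab_n2, emb_dim)
        self.embed_n3 = nn.Embedding(vocab_n3, emb_dim)
        self.lstm_n1 = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.lstm_n2 = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.lstm_n3 = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(3 * hidden_dim, num_classes)

    def forward(self, x_n1, x_n2, x_n3):
        # each stream: embed -> LSTM, keep the last hidden state
        _, (h1, _) = self.lstm_n1(self.embed_n1(x_n1))
        _, (h2, _) = self.lstm_n2(self.embed_n2(x_n2))
        _, (h3, _) = self.lstm_n3(self.embed_n3(x_n3))
        # concatenate the three representations along the feature dimension
        combined = torch.cat((h1[-1], h2[-1], h3[-1]), dim=1)
        return self.fc(combined)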
But I get an error when I'm using the embedding layers.
In the model's __init__():
vocab_size_n1 = pow(26, 1) + 2 # A + "<space>" + "<pad-0>"
self.embed_n1 = nn.Embedding(vocab_size_n1, self.embedding_size)
vocab_size_n2 = pow(26, 2) + 3 # A^2 + "<space>" + "<pad-0>"
self.embed_n2 = nn.Embedding(vocab_size_n2, self.embedding_size)
vocab_size_n3 = pow(26, 3) + 3 # A^3 + "<space>" + "<pad-0>"
self.embed_n3 = nn.Embedding(vocab_size_n3, self.embedding_size)
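(These evaluate to vocab_size_n1 = 28, vocab_size_n2 = 679 and vocab_size_n3 = 17579, so the tri-gram table is by far the largest.)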
In forward():
x_n1 = self.embed_n1(x_n1.type(torch.LongTensor).to(device=device))
x_n2 = self.embed_n2(x_n2.type(torch.LongTensor).to(device=device))
x_n3 = self.embed_n3(x_n3.type(torch.LongTensor).to(device=device))
So the first n1 embedding works fine, but the second and third give me a CUDA assertion error, which in this case means that some index in the input is out of range for the embedding, i.e. >= vocab_size. I checked that multiple times (see the sanity check at the end of this post) and everything is correct. When I delete any two of the three embeddings it works. It turns out that when I rearrange the embeddings like this:
x_n3 = self.embed_n3(x_n3.type(torch.LongTensor).to(device=device))
x_n2 = self.embed_n2(x_n2.type(torch.LongTensor).to(device=device))
x_n1 = self.embed_n1(x_n1.type(torch.LongTensor).to(device=device))
it works. This suggests to me that once the first embedding is used, the next two don't use the embedding I initialized for them but the first one instead. Therefore, when I start with n1, the vocab size of n1 is smaller than the number of unique bi- and tri-grams in the input and I get the assertion error (basically an index out of bounds). But when I start with the embedding that has the biggest vocab size (n3, tri-grams), it also works for the two smaller ones. I guess it wastes too much memory when an embedding uses a vocab size that is bigger than the actual number of unique tokens. What causes that and how can I fix it? Thank you
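For reference, the sanity check I mentioned above looks roughly like this (a minimal sketch; it reuses the variable names from the snippets, and on my data all three checks pass):

# Check that every index in each input stays inside its vocabulary,
# i.e. 0 <= index < vocab_size for the matching embedding table.
for name, x, vocab in [("n1", x_n1, vocab_size_n1),
                       ("n2", x_n2, vocab_size_n2),
                       ("n3", x_n3, vocab_size_n3)]:
    lo, hi = int(x.min().item()), int(x.max().item())
    print(f"{name}: min index {lo}, max index {hi}, vocab size {vocab}")
    assert 0 <= lo and hi < vocab, f"{name}: index out of range"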