Expand an existing Embedding and linear layer - NaN loss value


I am working on an image captioning project.
As the vocabulary grows, I have to expand the Embedding layer and the last linear layer. My decoder looks like this:

```python
import torch.nn as nn

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)  # token id -> embedding vector
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)   # hidden state -> vocabulary scores
        self.max_seg_length = max_seq_length
```
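Both vocabulary-dependent layers store one row per token: `nn.Embedding(num_embeddings, embedding_dim)` keeps a weight of shape `(num_embeddings, embedding_dim)`, and `nn.Linear(in_features, out_features)` keeps a weight of shape `(out_features, in_features)` plus a bias of shape `(out_features,)`. So growing the vocabulary means adding rows along dimension 0 of all three tensors. A quick shape check, with arbitrary sizes assumed purely for illustration:

```python
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration
vocab_size, embed_size, hidden_size = 100, 16, 32

embed = nn.Embedding(vocab_size, embed_size)
linear = nn.Linear(hidden_size, vocab_size)

# The vocabulary is indexed along dim 0 of every one of these tensors
print(tuple(embed.weight.shape))   # (100, 16)
print(tuple(linear.weight.shape))  # (100, 32)
print(tuple(linear.bias.shape))    # (100,)
```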

When the vocabulary grows by new_dim_num entries, I copy the old weights and concatenate them with new rows to build a new Embedding and linear layer, like this:

```python
from copy import deepcopy

import torch
import torch.nn as nn

# expand decoder layers
# new_dim_num is the number of entries added to the vocabulary
old_embed = deepcopy(decoder.embed.weight.data)
new_row_embed = torch.Tensor(new_dim_num, args.embed_size).to(device)
new_embed = torch.cat((old_embed, new_row_embed), 0)
old_linear = deepcopy(decoder.linear.weight.data)
new_row_linear = torch.Tensor(new_dim_num, args.hidden_size).to(device)
new_linear = torch.cat((old_linear, new_row_linear), 0)
old_bias = deepcopy(decoder.linear.bias.data)
new_row_bias = torch.Tensor(new_dim_num).to(device)
new_bias = torch.cat((old_bias, new_row_bias))

decoder.embed = nn.Embedding.from_pretrained(new_embed)
decoder.linear.weight = nn.Parameter(new_linear)
decoder.linear.bias = nn.Parameter(new_bias)
decoder.linear.out_features = vocab_size

for param in decoder.parameters():
    ...  # loop body not included in the original post
```

But now the loss is always NaN or extremely large. I've checked the output of the RNN, and the output values are abnormally large.
Could someone point out what is wrong here?
Thank you.
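A likely cause, offered as a hedged diagnosis rather than something confirmed in the thread: `torch.Tensor(new_dim_num, args.embed_size)` behaves like `torch.empty`, so it allocates uninitialized memory and the appended rows hold whatever bytes were already there, which can be huge values or NaN. Separately, `nn.Embedding.from_pretrained` defaults to `freeze=True`, so the expanded embedding would not be trained. Below is a sketch of the same concatenation approach with the new rows initialized explicitly. The helper name `expand_rows`, the sizes, and the init choices (N(0, 1) for embedding rows, which matches `nn.Embedding`'s default, and zeros for the linear rows) are all assumptions for illustration.

```python
import torch
import torch.nn as nn

def expand_rows(old: torch.Tensor, extra: int, init: str) -> torch.Tensor:
    """Hypothetical helper: append `extra` explicitly initialized rows along dim 0."""
    new_rows = old.new_empty((extra, *old.shape[1:]))  # same dtype/device as `old`
    if init == "normal":
        nn.init.normal_(new_rows)  # N(0, 1), matching nn.Embedding's default init
    else:
        nn.init.zeros_(new_rows)   # new tokens start with zero weight/bias
    return torch.cat((old, new_rows), dim=0)

# Hypothetical sizes for illustration
old_vocab, new_dim_num, embed_size, hidden_size = 10, 3, 8, 16
embed = nn.Embedding(old_vocab, embed_size)
linear = nn.Linear(hidden_size, old_vocab)

new_embed = expand_rows(embed.weight.detach(), new_dim_num, "normal")
new_weight = expand_rows(linear.weight.detach(), new_dim_num, "zeros")
new_bias = expand_rows(linear.bias.detach(), new_dim_num, "zeros")

# freeze=False keeps the embedding trainable (from_pretrained defaults to freeze=True)
new_embed_layer = nn.Embedding.from_pretrained(new_embed, freeze=False)
new_linear = nn.Linear(hidden_size, old_vocab + new_dim_num)
new_linear.weight = nn.Parameter(new_weight)
new_linear.bias = nn.Parameter(new_bias)
```

Building a fresh `nn.Linear` of the new size also keeps `out_features` consistent without patching the attribute by hand.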

I think my approach was wrong.
This is how I made it work:

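```python
import torch.nn as nn

# Keep references to the trained weights before replacing the layers
old_embed = decoder.embed.weight.data
old_weight = decoder.linear.weight.data
old_bias = decoder.linear.bias.data
tmp_dim = old_embed.size(0)  # old vocabulary size

# Fresh layers sized for the new vocabulary; the new rows get PyTorch's default init
decoder.embed = nn.Embedding(vocab_size, args.embed_size).to(device)
decoder.linear = nn.Linear(args.hidden_size, vocab_size).to(device)

# Copy the trained rows back into the first tmp_dim positions
decoder.embed.weight.data[:tmp_dim, :] = old_embed
decoder.linear.weight.data[:tmp_dim] = old_weight
decoder.linear.bias.data[:tmp_dim] = old_bias
```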

Much easier and more elegant!
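The same idea can be wrapped in a reusable function. This is a sketch that assumes a decoder with `embed` and `linear` attributes as in the question; the names `expand_decoder_vocab` and `TinyDecoder` are hypothetical. It copies under `torch.no_grad()` instead of going through `.data`, places the new layers on the same device as the old weights, and derives `tmp_dim` from the old embedding. Because the layers are new `Parameter` objects, an optimizer built before the expansion still holds the old tensors, so it has to be re-created afterwards.

```python
import torch
import torch.nn as nn

def expand_decoder_vocab(decoder: nn.Module, new_vocab_size: int) -> None:
    """Hypothetical helper: grow decoder.embed and decoder.linear, keeping trained rows."""
    old_embed = decoder.embed.weight
    old_linear = decoder.linear
    tmp_dim = old_embed.size(0)  # old vocabulary size
    device = old_embed.device

    # Fresh layers: rows from tmp_dim onward keep PyTorch's default initialization
    new_embed = nn.Embedding(new_vocab_size, old_embed.size(1)).to(device)
    new_linear = nn.Linear(old_linear.in_features, new_vocab_size).to(device)

    with torch.no_grad():  # in-place copy into leaf parameters without tracking grads
        new_embed.weight[:tmp_dim] = old_embed
        new_linear.weight[:tmp_dim] = old_linear.weight
        new_linear.bias[:tmp_dim] = old_linear.bias

    decoder.embed = new_embed
    decoder.linear = new_linear

class TinyDecoder(nn.Module):  # stand-in for DecoderRNN, for illustration only
    def __init__(self, vocab_size: int, embed_size: int = 8, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.linear = nn.Linear(hidden_size, vocab_size)

decoder = TinyDecoder(vocab_size=10)
before = decoder.embed.weight.detach().clone()
expand_decoder_vocab(decoder, 13)

# Re-create the optimizer so it tracks the new Parameter objects
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
```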