Can you use sparse encodings for a word2vec model, or do you NEED one-hot encoding?

I am trying to train a seq2seq/encoder-decoder word2vec model using pre-trained GloVe embeddings. The embedding matrix contains 400k words with 100 dimensions.
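
For context, the embedding matrix is loaded roughly like this (a simplified sketch; the file path and helper name are just illustrative, and the rows for the special tokens and the 0 padding index are added separately and omitted here):

import numpy as np

def load_glove(path):
    """Read a GloVe text file into a word list and a (num_words, 100) float64 matrix."""
    words, vectors = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            words.append(parts[0])
            vectors.append(np.asarray(parts[1:], dtype=np.float64))
    return words, np.stack(vectors)

vocab, embeddings = load_glove("glove.6B.100d.txt")  # ~400k words x 100 dims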

Each training example is a sequence of sparse ints (token indices). There are a few special tokens for unknown words, end of sequence, and end of message (the LSTM should hopefully output multiple messages in a sequence). Here's an example:

original_string = "This is a test, Hello world! rsegklp EOMTOKEN EOSTOKEN"
string_to_int = [42, 19, 12, 733, 6, 13080, 90, 810, 1, 3, 4]
int_to_string = ['this', 'is', 'a', 'test', ',', 'hello', 'world', '!', '<unk>', 'EOMTOKEN', 'EOSTOKEN']

An SOS_TOKEN for start of sequence and 0 padding also get added later.

So what gets fed to the network is a tensor like this (one training example):
[ 319, 38505, 7, 247, 11, 132, 223, 474, 5, 142, 8, 5, 2074, 68, 14, 12, 1893, 10, 82, 12323, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
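
The conversion itself is nothing special; roughly (a simplified sketch: `word_to_index` is built from the GloVe vocab plus the special tokens, `max_len` is illustrative, and punctuation handling is omitted):

import torch

def encode(text, word_to_index, max_len):
    """Map a whitespace-tokenised string to int indices, with <unk> fallback and 0-padding."""
    unk = word_to_index["<unk>"]
    ids = [word_to_index.get(tok, unk) for tok in text.lower().split()]
    ids = ids[:max_len]                    # truncate anything too long
    ids += [0] * (max_len - len(ids))      # pad with 0 up to max_len
    return torch.tensor(ids, dtype=torch.long)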

My network consists of an embedding layer, an encoder LSTM, a decoder LSTM, and a dense layer followed by a softmax. The output at each timestep is a 400k-dimensional vector (one entry per token in the word embedding) giving the probability of each word. I use argmax to convert this back to a sparse int, similar to my inputs as shown above.

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # used in forward() below

class LSTMNetwork(nn.Module):
    def __init__(self, lstm_hidden_dim, embeddings, max_output_len):
        super().__init__()
        self.vocab_size = embeddings.shape[0]
        self.max_output_len = max_output_len
        self.embedding_dim = embeddings.shape[1] 
        self.lstm_hidden_dim = lstm_hidden_dim
        
        self.embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(embeddings), freeze=True)
        self.encoder_lstm_layer = nn.LSTMCell(self.embedding_dim, self.lstm_hidden_dim) # features, hidden_dimension
        self.decoder_lstm_layer = nn.LSTMCell(self.embedding_dim, self.lstm_hidden_dim)
        self.fc1 = nn.Linear(self.lstm_hidden_dim, self.vocab_size) # dense layer maps hidden state to one logit per vocabulary word
      
    def forward(self, x, y_input):
        if len(x.shape) == 1: # if this function has been given 1 example, add m dimension.
            x = x.unsqueeze(0)
        batch_size = x.shape[0]

        assert self.max_output_len == y_input.shape[1]
        
        # Run inputs through embedding layer
        enc_input = self.embedding_layer(x).to(torch.float32) # embedding layer outputs float64, we need 32
        dec_input = self.embedding_layer(y_input).to(torch.float32)
        
        # Set up initial states:
        h_t = torch.zeros(batch_size, self.lstm_hidden_dim, dtype=torch.float32).to(device)
        c_t = torch.zeros(batch_size, self.lstm_hidden_dim, dtype=torch.float32).to(device)
        
        # RNN Loop
        enc_input = torch.transpose(enc_input, 1, 0) # The data loader passes (m, time_step, n); rearrange to (time_step, m, n) so we can loop over (m, n) slices, one per timestep.
        dec_input = torch.transpose(dec_input, 1, 0)
                
        # Encoder
        for x_t in enc_input:
            h_t, c_t = self.encoder_lstm_layer(x_t, (h_t, c_t)) # LSTM cell passes h_t and c_t forward.

        # Decoder        
        outputs = []
        for i, y_inp_t in enumerate(dec_input):
            h_t, c_t = self.decoder_lstm_layer(y_inp_t, (h_t, c_t))
            y_out_t = self.fc1(h_t) # At each timestep, pass h_t through the dense layer, then softmax for the output distribution.
            y_out_t = F.softmax(y_out_t, dim=0)
            print(f"Pre argmax: {y_out_t.requires_grad=}")
            y_out_t = torch.argmax(y_out_t, dim=1, keepdim=True).to(torch.float32)
            print(f"Post argmax: {y_out_t.requires_grad=}")
            outputs.append(y_out_t) # append to list of timesteps.

        y_pred = torch.concat(outputs, dim=1) # concatenate all timesteps to one tensor
        
        return y_pred
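
The model and optimizer are set up in the usual way, roughly like this (the hidden size, output length, and optimizer here are just illustrative; `embeddings` is the 400k x 100 GloVe matrix as a NumPy array):

model = LSTMNetwork(lstm_hidden_dim=256, embeddings=embeddings, max_output_len=71).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)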

The train loop is fairly standard:

        # Train Loop
        model.train()
        current_loss = 0 # used to track average loss (per training example) for each batch
        
        for i, (x_enc_inp, y_dec_inp, y_truth) in enumerate(train_loader):
            y_pred = model(x_enc_inp, y_dec_inp)
            loss = criterion(y_pred, y_truth.to(torch.float32)) # truths are ints, need to be float32

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            current_loss += loss.item() # take the scalar value so we don't hold on to the graph

using criterion = nn.CrossEntropyLoss()

When training, I am getting this error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

After debugging, I have found that the argmax in the decoder section of the forward pass produces a tensor with requires_grad=False.

It makes sense that you cannot differentiate through argmax, and it also makes sense that the loss cannot be calculated properly without a probability for every word.
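
A minimal illustration of what I am seeing:

import torch

logits = torch.randn(1, 5, requires_grad=True)
probs = torch.softmax(logits, dim=1)
print(probs.requires_grad)   # True - still attached to the autograd graph

ids = torch.argmax(probs, dim=1)
print(ids.requires_grad)     # False - argmax returns integer indices with no grad_fn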

Do I need to convert the inputs and outputs to one-hot encodings and avoid using argmax? In the past with TensorFlow I have noticed this massively increases compute time and memory usage. Is there a way to keep the data sparse?