How to feed softmax output to an embedding layer as RNN input

I’m implementing this network:


How can I feed the softmax output to the embedding layer?

How about applying argmax to the output of softmax as follows?

# Original notebook:
# https://colab.research.google.com/gist/crcrpar/adecbb7a838e62a081a05a9cd53be048/untitled3.ipynb

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):

    def __init__(self, num_classes=10, dim_embed=128):
        super().__init__()
        self.softmax = nn.Softmax(dim=1)
        self.embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=dim_embed)

    def forward(self, x):
        # Turn the logits into class probabilities, pick the most likely
        # class per sample, and look up that class's embedding vector.
        prob = self.softmax(x)
        indices = torch.argmax(prob, dim=1)
        return self.embed(indices)

model = Model()
# N: 32, C: 10
sample_input = torch.randn(32, 10)
# This `no_grad` shouldn't be in your training code :)
with torch.no_grad():
    embeddings = model(sample_input)
print(embeddings.shape)
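
One caveat with the snippet above: `argmax` is not differentiable, so no gradient will reach whatever produced `x`. If that matters for your training, a common workaround is to skip the hard lookup and instead take the probability-weighted average of the embedding rows, i.e. `prob @ embed.weight`. Below is a minimal sketch of that idea; `SoftModel` is just an illustrative name, not something from the original code.

# Sketch of a differentiable alternative: instead of argmax + lookup,
# take the probability-weighted sum of all embedding vectors.
class SoftModel(nn.Module):

    def __init__(self, num_classes=10, dim_embed=128):
        super().__init__()
        self.softmax = nn.Softmax(dim=1)
        self.embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=dim_embed)

    def forward(self, x):
        # (N, num_classes) probabilities
        prob = self.softmax(x)
        # Expected embedding: weighted sum over the embedding table,
        # shape (N, dim_embed). Gradients flow back through the softmax.
        return prob @ self.embed.weight

soft_model = SoftModel()
out = soft_model(torch.randn(32, 10))
print(out.shape)  # torch.Size([32, 128])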