How to feed a softmax output to an embedding layer as RNN input

I’m implementing the following network (the diagram from the original post is not reproduced here):

How can I feed the softmax output to the embedding layer?

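How about applying `argmax` to the softmax output, as follows? Note that `argmax` is not differentiable, so no gradient flows back through it to the softmax input. If the network must be trained end to end, use the probabilities to take a weighted sum of the embedding vectors instead (`prob @ embed.weight`).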

# -*- coding: utf-8 -*-

"""Automatically generated by Colaboratory.

Original file is located at: (URL not included in this post)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Map a batch of class scores to class embeddings.

    The scores are normalized with a softmax over dim 1 and then turned into
    embedding vectors in one of two ways:

    * hard (default): pick the most likely class with ``argmax`` and look up
      its row in the embedding table. ``argmax`` is not differentiable, so no
      gradient reaches ``x`` through this path.
    * soft (``soft=True``): return the probability-weighted sum of all
      embedding rows (``prob @ embed.weight``). This path is differentiable
      with respect to ``x``, so it can be used for end-to-end training.

    Args:
        num_classes: Number of classes. This is the number of rows in the
            embedding table, and it must match the size of dim 1 of ``x``.
        dim_embed: Size of each embedding vector.
    """

    def __init__(self, num_classes=10, dim_embed=128):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, the assignments below raise AttributeError.
        super().__init__()
        self.softmax = nn.Softmax(dim=1)
        self.embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=dim_embed)

    def forward(self, x, *, soft=False):
        """Embed the classes scored by ``x``.

        Args:
            x: Class scores. The softmax is taken over dim 1. The sample usage
                passes shape ``(N, num_classes)``, and the soft path requires
                exactly that 2-D shape.
            soft: If False (default), embed the argmax class. This path is not
                differentiable with respect to ``x``. If True, return the
                softmax-weighted mixture of all embedding rows.

        Returns:
            The embeddings. For a 2-D ``x`` of shape ``(N, num_classes)``, the
            shape is ``(N, dim_embed)``.
        """
        prob = self.softmax(x)
        if soft:
            # (N, C) @ (C, D) -> (N, D): expected embedding under `prob`.
            return prob @ self.embed.weight
        indices = torch.argmax(prob, dim=1)
        return self.embed(indices)

# Demo: embed a batch of N=32 score vectors over C=10 classes, producing
# one 128-dimensional embedding per row.
model = Model(num_classes=10, dim_embed=128)
sample_input = torch.randn((32, 10))  # shape (N, C)

# Gradient tracking is disabled only for this demo forward pass. Do not wrap
# the forward pass in `no_grad` in training code, or no gradients will be
# recorded for backpropagation.
with torch.no_grad():
    embeddings = model(sample_input)