It seems like the `nn.Transformer` module could benefit from a complementary `nn.Generate` module that makes it easy to use different decoding/sampling/generation methods with the existing framework.

For instance:
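```python
import torch.nn as nn

# nn.Generate is the proposed mixin; it does not exist in PyTorch today
class DummyModel(nn.Module, nn.Generate):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # set some model ops
        self.embedding = nn.Embedding(...)  # token ids -> vectors
        self.model = nn.Transformer(...)
        self.output = nn.Linear(...)        # vectors -> vocabulary logits

    def forward(self, src, tgt):
        # nn.Transformer.forward takes both a source and a target sequence
        out = self.model(self.embedding(src), self.embedding(tgt))
        return self.output(out)

model = DummyModel(...)

# perform training
for x, y in batch:
    y_hat = model(x, y)
    loss = criterion(y_hat, y)

# perform inference
generated_output = model.generate(..., method='greedy')  # or 'beam-search', 'top-k', 'nucleus', etc.
```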
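Since `nn.Generate` is not an existing PyTorch module, here is a minimal, model-agnostic sketch of what the proposed `generate(..., method=...)` could do under the hood. It assumes the model is wrapped as a callable that maps a token prefix to next-token logits. The names `generate`, `STRATEGIES`, the `pick_*` helpers and `toy_logits` are all hypothetical. It uses plain Python lists instead of tensors so it stays self-contained, and it covers greedy, top-k and nucleus (top-p) selection while leaving beam search out for brevity.

```python
import math
import random
from typing import Callable, List, Sequence

# Assumed interface (not a PyTorch API): token prefix -> next-token logits.
LogitsFn = Callable[[Sequence[int]], List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_greedy(probs: List[float], rng: random.Random) -> int:
    # Always choose the single most probable token.
    return max(range(len(probs)), key=probs.__getitem__)


def pick_top_k(probs: List[float], rng: random.Random, k: int = 5) -> int:
    # Keep the k most probable tokens and sample among them;
    # random.choices treats the weights as relative, so this renormalises.
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]


def pick_nucleus(probs: List[float], rng: random.Random, p: float = 0.9) -> int:
    # Keep the smallest set of top tokens whose cumulative probability
    # reaches p, then sample among them.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


STRATEGIES = {
    "greedy": pick_greedy,
    "top-k": pick_top_k,
    "nucleus": pick_nucleus,
}


def generate(logits_fn: LogitsFn, prompt: Sequence[int], max_new_tokens: int,
             method: str = "greedy", eos_id=None, seed=None, **kwargs) -> List[int]:
    # Autoregressive loop: score the prefix, pick one token, append, repeat.
    rng = random.Random(seed)
    pick = STRATEGIES[method]
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        probs = softmax(logits_fn(tokens))
        next_id = pick(probs, rng, **kwargs)  # e.g. k=... or p=...
        tokens.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break  # stop early at the end-of-sequence token
    return tokens


def toy_logits(tokens: Sequence[int]) -> List[float]:
    # Hypothetical toy "model" over 4 tokens: strongly prefers (last + 1) % 4.
    nxt = (tokens[-1] + 1) % 4
    return [5.0 if i == nxt else 0.0 for i in range(4)]


print(generate(toy_logits, [0], max_new_tokens=5, method="greedy"))
# prints [0, 1, 2, 3, 0, 1]
```

Keeping the decoding loop separate from the model, and selecting the strategy by name, is one way the proposed `method` argument could stay independent of any particular architecture; each strategy only ever sees a probability vector.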
Are there any plans to implement something like this in the future?