Can a model with a batched forward pass predict a single example?

Hi PyTorch users,

I’m still quite new to PyTorch, but I’ve already spent some time on this problem.

So I’ve got this demo LSTM model, which works on batches.

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=1, num_layers=2):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers

        # nn.LSTM expects input of shape (seq_len, batch, input_dim) by default
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)

        self.linear = nn.Linear(self.hidden_dim, output_dim)

    def init_hidden(self):
        # Zero initial (h0, c0), each of shape (num_layers, batch, hidden_dim)
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))

    def forward(self, input):
        # The batch size is fixed at construction time via self.batch_size
        lstm_out, self.hidden = self.lstm(input.view(-1, self.batch_size, self.input_dim))
        # lstm_out[-1] is the output of the last time step for every sequence
        y_pred = self.linear(lstm_out[-1].view(self.batch_size, -1))
        return y_pred.view(-1)
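The model above hard-codes `self.batch_size` in `forward`, which is what ties it to one batch size. Here is a minimal sketch of a variant that reads the batch size from the input instead, assuming the input is laid out as (seq_len, batch, input_dim). `FlexibleLSTM` is a hypothetical name, and the 2-D branch for a lone sequence is an assumption of this sketch:

```python
import torch
import torch.nn as nn


class FlexibleLSTM(nn.Module):
    """Hypothetical variant of the LSTM above that infers the batch size from the input."""

    def __init__(self, input_dim, hidden_dim, output_dim=1, num_layers=2):
        super().__init__()
        self.input_dim = input_dim
        # Default layout for nn.LSTM: (seq_len, batch, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        # Assumption: a single sequence arrives as (seq_len, input_dim),
        # so give it a batch dimension of size 1
        if input.dim() == 2:
            input = input.unsqueeze(1)
        lstm_out, _ = self.lstm(input)       # hidden state defaults to zeros
        y_pred = self.linear(lstm_out[-1])   # last time step: (batch, output_dim)
        return y_pred.view(-1)


model = FlexibleLSTM(input_dim=1, hidden_dim=16)
batch = torch.randn(20, 8, 1)   # seq_len=20, batch=8, input_dim=1
single = torch.randn(20, 1)     # one sequence without a batch dimension
print(model(batch).shape)   # torch.Size([8])
print(model(single).shape)  # torch.Size([1])
```

Nothing in the LSTM or Linear weights depends on the batch size, so only the reshaping in `forward` needs to change.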

I tried it on a sine signal and it seems to learn okay.

But I feed it input tensors in batches of length batch_size. Now I’m wondering how to achieve something similar to the Keras method model.predict(X_test), so that I can feed only a single example to the LSTM model. Is there a simple solution?

You could just set the batch size to 1 and get a single-sample output.
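Setting the batch size to 1 only changes the input shape, not the trained weights. A small sketch (sizes chosen arbitrarily) showing that `nn.LSTM` gives the same output for one sequence whether it is run alone with a batch dimension of 1 or as part of a batch of 32:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# The LSTM weights have no batch dimension, so one module handles any batch size
lstm = nn.LSTM(input_size=1, hidden_size=8, num_layers=2)

batch = torch.randn(10, 32, 1)     # (seq_len, batch, input_dim)
out_batch, _ = lstm(batch)

# Run sequence number 5 on its own, keeping a batch dimension of size 1
single = batch[:, 5:6, :]          # (seq_len, 1, input_dim)
out_single, _ = lstm(single)

# The single-sequence output matches the corresponding slice of the batched run
print(torch.allclose(out_single[:, 0], out_batch[:, 5], atol=1e-5))  # True
```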

Isn't the .view(-1) in the return statement superfluous?
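Whether `.view(-1)` is needed depends on the target shape. A sketch assuming `output_dim=1` and 1-D targets: the Linear layer returns (batch, 1), and `.view(-1)` flattens it to (batch,) so it lines up with the targets instead of broadcasting:

```python
import torch
import torch.nn as nn

linear = nn.Linear(16, 1)          # output_dim=1 as in the model above
last_step = torch.randn(8, 16)     # (batch, hidden_dim)

y = linear(last_step)
print(y.shape)            # torch.Size([8, 1])
print(y.view(-1).shape)   # torch.Size([8])

# With a (8, 1) prediction and a (8,) target, MSELoss would broadcast
# to (8, 8) and warn; flattening first keeps the shapes aligned.
target = torch.randn(8)
loss = nn.MSELoss()(y.view(-1), target)
```

If the targets were stored as (batch, 1) instead, the reshape could be dropped.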

But then I’d have to train using stochastic gradient descent, right, showing one training example at a time? What I want to achieve is to train on batches and predict a single example.
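Training on mini-batches and predicting one example do not conflict, because the same weights accept any batch size. A minimal sketch on hypothetical sine-wave windows (target = next value), using `nn.LSTM` plus a Linear head directly rather than the class above; the data, sizes, and the `forward` helper are all assumptions for illustration:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: sliding windows over a sine wave
seq_len, n = 20, 256
signal = torch.sin(torch.linspace(0, 40 * math.pi, n + seq_len + 1))
X = torch.stack([signal[i:i + seq_len] for i in range(n)])   # (n, seq_len)
Y = torch.stack([signal[i + seq_len] for i in range(n)])     # (n,)

lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=2)
head = nn.Linear(16, 1)
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()), lr=1e-2)
loss_fn = nn.MSELoss()

def forward(x):
    # x: (batch, seq_len) -> (seq_len, batch, 1), the layout nn.LSTM expects
    out, _ = lstm(x.t().unsqueeze(-1))
    return head(out[-1]).view(-1)   # (batch,)

# Training on mini-batches of 32 sequences
for epoch in range(5):
    perm = torch.randperm(n)
    for i in range(0, n, 32):
        idx = perm[i:i + 32]
        optimizer.zero_grad()
        loss = loss_fn(forward(X[idx]), Y[idx])
        loss.backward()
        optimizer.step()

# Prediction on one sequence: add a batch dimension of size 1
with torch.no_grad():
    x_test = X[0]                          # (seq_len,)
    y_hat = forward(x_test.unsqueeze(0))   # (1,)
print(y_hat.shape)  # torch.Size([1])
```

`torch.no_grad()` plays roughly the role of Keras `predict` here; if the model contained dropout, calling `.eval()` on the modules before predicting would also be needed.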

I think the key to my concern is that I don’t properly understand what init_hidden does: the dimensions it returns include self.batch_size, which I don’t need at the prediction stage.
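`init_hidden` just builds the initial (h0, c0) states, each of shape (num_layers, batch, hidden_dim), so its batch dimension must match the input's. A sketch with a hypothetical `init_hidden(batch_size)` that takes the batch size per call, also checking that omitting the states gives the same result, since `nn.LSTM` defaults them to zeros:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, hidden_dim = 2, 16
lstm = nn.LSTM(input_size=1, hidden_size=hidden_dim, num_layers=num_layers)

def init_hidden(batch_size):
    # Hypothetical helper: batch size passed per call instead of stored on the model
    return (torch.zeros(num_layers, batch_size, hidden_dim),
            torch.zeros(num_layers, batch_size, hidden_dim))

x = torch.randn(20, 1, 1)   # one sequence: (seq_len, batch=1, input_dim)
out_explicit, _ = lstm(x, init_hidden(x.size(1)))
out_default, _ = lstm(x)    # no states passed: nn.LSTM uses zeros
print(torch.allclose(out_explicit, out_default))  # True
```

So for zero initial states the helper can simply be dropped, which removes the fixed batch size from that part of the model.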


This sounds like a standard problem: writing a model for single samples. But for training, mini-batches should be the default input for the model.
Is there something like myTrainModel = torch.convertForBatchInputs(myInferenceModel)?
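As far as I know there is no `torch.convertForBatchInputs`; the closest existing tool I know of is `torch.func.vmap` (PyTorch 2.0+), which turns a function written for one sample into one that maps over a batch dimension. A sketch with a hypothetical per-sample function (`predict_one` and the weight `w` are made up for illustration):

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
w = torch.randn(4)

def predict_one(x):
    # Written for a single sample of shape (4,); torch.dot only accepts 1-D tensors
    return torch.tanh(torch.dot(w, x))

batch = torch.randn(8, 4)
batched_predict = vmap(predict_one)   # maps over dim 0 of the input
y = batched_predict(batch)
print(y.shape)  # torch.Size([8])

# Same result as looping over the samples one at a time
y_loop = torch.stack([predict_one(x) for x in batch])
print(torch.allclose(y, y_loop, atol=1e-6))  # True
```

That said, built-in layers such as `nn.Linear` and `nn.LSTM` already accept a batch dimension, so the usual approach is the reverse: write the model for batches and pass a batch of size 1 at inference time.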