Issues with an Ignite training loop that works fine with plain PyTorch

Hello, I have tabular data and am passing it to a bi-directional LSTM network defined as such:

class BiLSTM(nn.Module):
    """Two stacked bidirectional LSTM layers followed by a small dense head.

    The network maps a batch-first input of shape ``(batch, seq_len, inp_size)``
    to one value per sample. Because the output of ``dense1`` is flattened
    before the final ``Linear(50, 1)``, the head only works when ``seq_len``
    is 1.

    Args:
        hidden_size: Number of hidden units per direction in each LSTM.
        num_layers: Stored on the instance but not used to build the
            network; the two LSTM layers are always created explicitly.
        inp_size: Number of features in each input time step.
        dropout: Dropout probability applied after each LSTM layer.
    """

    def __init__(self, hidden_size=48, num_layers=2, inp_size=5, dropout=0.2):
        super().__init__()

        self.hidden_size = hidden_size
        self.input_size = inp_size
        self.num_layers = num_layers

        # A bidirectional LSTM concatenates both directions, so every
        # downstream layer sees twice the hidden size.
        lstm_out_size = self.hidden_size * 2

        self.lstm1 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size,
                             batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(input_size=lstm_out_size, hidden_size=self.hidden_size,
                             batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # Derived from hidden_size instead of the former hard-coded 96 so that
        # non-default hidden sizes work; the default (48 * 2) is unchanged.
        self.dense1 = nn.Linear(lstm_out_size, 50)
        self.flatten = nn.Flatten()
        self.dense2 = nn.Linear(50, 1)

    def forward(self, x, hidden=None):
        """Run the network on ``x``.

        Args:
            x: Batch-first input tensor whose last dimension equals ``inp_size``.
            hidden: Optional initial ``(h_0, c_0)`` state. The same state is
                used to initialise both LSTM layers.

        Returns:
            A tuple ``(prediction, hidden)`` where ``hidden`` is the final
            state of the first LSTM layer. The final state of the second
            layer is discarded.
        """
        initial_hidden = hidden
        x, hidden = self.lstm1(x, hidden)
        x = self.dropout(x)
        x, _ = self.lstm2(x, initial_hidden)
        x = self.dropout(x)
        x = self.dense1(x)
        x = self.flatten(x)
        return self.dense2(x), hidden
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─LSTM: 1-1                              [100, 1, 96]              21,120
├─Dropout: 1-2                           [100, 1, 96]              --
├─LSTM: 1-3                              [100, 1, 96]              56,064
├─Dropout: 1-4                           [100, 1, 96]              --
├─Linear: 1-5                            [100, 1, 50]              4,850
├─Flatten: 1-6                           [100, 50]                 --
├─Linear: 1-7                            [100, 1]                  51
==========================================================================================
Total params: 82,085
Trainable params: 82,085
Non-trainable params: 0
Total mult-adds (M): 0.08
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.19
Params size (MB): 0.33
Estimated Total Size (MB): 0.52
==========================================================================================

I wanted to use EarlyStopping from the Ignite library, but there didn’t seem to be a simpler way to do so than adopting the whole Ignite workflow. So I followed the quick start and the EarlyStopping docs and implemented this:

# Use the second-to-last split.
split_num = len(val_loaders)-2
# Pair every feature batch with its label batch. zip() alone returns a
# one-shot iterator that is empty after the first pass, but these loaders are
# iterated repeatedly (once per training epoch and again by the evaluator),
# so the pairs are materialised into lists that can be re-iterated.
train_loader = list(zip(train_loaders[split_num], train_label_loaders[split_num]))
val_loader = list(zip(val_loaders[split_num], val_label_loaders[split_num]))

# Model, loss and optimizer shared by the trainer and the evaluator.
net = BiLSTM()
loss_func = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# NOTE(review): BiLSTM.forward returns a (prediction, hidden) tuple, while the
# default supervised update step shown in the traceback does
# `y_pred = model(x); loss = loss_fn(y_pred, y)`. The tuple would reach
# loss_func unchanged, so a custom step is probably needed; confirm.
trainer = create_supervised_trainer(net, optimizer, loss_func)

# Metrics computed by the evaluator; the keys are the names read back from
# `evaluator.state.metrics` in the logging handlers.
# NOTE(review): Accuracy is normally a classification metric; with an MSE
# regression loss it is presumably not meaningful here; confirm.
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(loss_func)
}
evaluator = create_supervised_evaluator(net, metrics=val_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=1))
def log_training_loss(trainer):
    """Print the current epoch and the engine's latest output, labelled as the loss."""
    message = f"Epoch[{trainer.state.epoch}] Loss: {trainer.state.output:.2f}"
    print(message)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """Run the evaluator over the training data and print its metrics."""
    evaluator.run(train_loader)
    # The evaluator stores the computed metrics on its state after the run.
    metrics = evaluator.state.metrics
    summary = f"Training Results - Epoch: {trainer.state.epoch}  Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    print(summary)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """Run the evaluator over the validation data and print its metrics."""
    evaluator.run(val_loader)
    # The evaluator stores the computed metrics on its state after the run.
    metrics = evaluator.state.metrics
    summary = f"Validation Results - Epoch: {trainer.state.epoch}  Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    print(summary)

    
def score_function(engine):
    """Return the negated ``'loss'`` metric of ``engine`` so that a lower loss gives a higher score."""
    return -engine.state.metrics['loss']

# Early stopping scored with the negated loss (see score_function) and bound to
# `trainer`, the engine it is meant to stop.
handler = EarlyStopping(patience=50, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
# NOTE(review): the same evaluator is also run on train_loader in
# log_training_results, so this handler fires after both the training-set and
# the validation-set evaluation of every epoch; confirm that is intended.
evaluator.add_event_handler(Events.COMPLETED, handler)

# Checkpointing of the model under the key 'mymodel', every second epoch.
# NOTE(review): handler2 is attached to `trainer`, but score_function reads
# engine.state.metrics['loss'] and no metrics are attached to the trainer in
# this snippet; verify that the lookup succeeds.
# NOTE(review): model_path ends in ".pth" but is presumably treated as a
# directory by ModelCheckpoint (create_dir=True); confirm.
model_path = PATH + str(split_num) + ".pth"
handler2 = ModelCheckpoint(model_path, filename_prefix='', score_function=score_function, n_saved=2, create_dir=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler2, {'mymodel': net})

# Start training for n_epochs passes over train_loader.
trainer.run(train_loader, max_epochs=n_epochs)

I end up getting an error about the input size, but I get no such error when using the same objects in my plain PyTorch training loops. Can anyone help with this?

(full error traceback)

Current run is terminating due to exception: input.size(-1) must be equal to input_size. Expected 5, got 1
Engine run is terminating due to exception: input.size(-1) must be equal to input_size. Expected 5, got 1

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-e8150550b1bd> in <module>
     49 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler2, {'mymodel': net})
     50 
---> 51 trainer.run(train_loader, max_epochs=n_epochs)

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in run(self, data, max_epochs, epoch_length, seed)
    700 
    701         self.state.dataloader = data
--> 702         return self._internal_run()
    703 
    704     @staticmethod

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _internal_run(self)
    773             self._dataloader_iter = None
    774             self.logger.error(f"Engine run is terminating due to exception: {e}")
--> 775             self._handle_exception(e)
    776 
    777         self._dataloader_iter = None

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470 
    471     @property

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _internal_run(self)
    743                     self._setup_engine()
    744 
--> 745                 time_taken = self._run_once_on_dataset()
    746                 # time is available for handlers but must be update after fire
    747                 self.state.times[Events.EPOCH_COMPLETED.name] = time_taken

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    848         except Exception as e:
    849             self.logger.error(f"Current run is terminating due to exception: {e}")
--> 850             self._handle_exception(e)
    851 
    852         return time.time() - start_time

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470 
    471     @property

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    831                 self.state.iteration += 1
    832                 self._fire_event(Events.ITERATION_STARTED)
--> 833                 self.state.output = self._process_function(self, self.state.batch)
    834                 self._fire_event(Events.ITERATION_COMPLETED)
    835 

~/miniconda3/envs/py/lib/python3.8/site-packages/ignite/engine/__init__.py in _update(engine, batch)
    100         optimizer.zero_grad()
    101         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
--> 102         y_pred = model(x)
    103         loss = loss_fn(y_pred, y)
    104         loss.backward()

~/miniconda3/envs/py/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-31-b16c34d32b2c> in forward(self, x, hidden)
     32     def forward(self, x, hidden=None):
     33         _ = hidden
---> 34         x, hidden = self.lstm1(x, hidden)
     35         x = self.dropout(x)
     36         x, _ = self.lstm2(x, _)

~/miniconda3/envs/py/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/miniconda3/envs/py/lib/python3.8/site-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    577             hx = self.permute_hidden(hx, sorted_indices)
    578 
--> 579         self.check_forward_args(input, hx, batch_sizes)
    580         if batch_sizes is None:
    581             result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,

~/miniconda3/envs/py/lib/python3.8/site-packages/torch/nn/modules/rnn.py in check_forward_args(self, input, hidden, batch_sizes)
    528 
    529     def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]):
--> 530         self.check_input(input, batch_sizes)
    531         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
    532 

~/miniconda3/envs/py/lib/python3.8/site-packages/torch/nn/modules/rnn.py in check_input(self, input, batch_sizes)
    176                     expected_input_dim, input.dim()))
    177         if self.input_size != input.size(-1):
--> 178             raise RuntimeError(
    179                 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
    180                     self.input_size, input.size(-1)))

RuntimeError: input.size(-1) must be equal to input_size. Expected 5, got 1

Hi @cmdkev , thanks for using Ignite !

I think that in your case the training step is a bit different from the basic one used in create_supervised_trainer. I suggest creating the trainer and evaluator in the following, more flexible way:

from ignite.engine import Engine

# Use the GPU when one is available and fall back to the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The step functions move every batch to `device`, so the model has to live on
# the same device; otherwise the forward pass fails with a device mismatch.
model = BiLSTM().to(device)
loss_func = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def training_step(engine, batch):
    """Run one optimisation step and carry the LSTM state to the next one.

    Args:
        engine: The engine executing this step. The hidden state is kept on
            ``engine.state.hidden`` between iterations.
        batch: A pair ``(features, labels)``.

    Returns:
        The loss of this batch as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    x, y = batch[0].to(device), batch[1].to(device)

    # Read the state from the engine that is actually running this step; it is
    # treated as absent when it has not been initialised yet.
    hidden = getattr(engine.state, "hidden", None)
    # The state is shaped (directions, batch, hidden); it cannot be reused
    # when the batch size changes (e.g. a smaller last batch).
    if hidden is not None and hidden[0].size(1) != x.size(0):
        hidden = None

    y_pred, hidden = model(x, hidden)
    loss = loss_func(y_pred, y)

    loss.backward()
    optimizer.step()

    # Detach before storing: otherwise the next iteration's backward() would
    # try to traverse this iteration's already-freed graph and raise.
    engine.state.hidden = tuple(h.detach() for h in hidden)

    return loss.item()


# Build the trainer from the custom step and start without a carried-over
# hidden state; training_step reads and updates this attribute.
trainer = Engine(training_step)
trainer.state.hidden = None

@torch.no_grad()
def evaluation_step(engine, batch):
    """Run the model on one batch in eval mode and return ``(y_pred, y)``.

    The hidden state returned by the model is discarded and no initial state
    is passed in.
    """
    model.eval()

    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    predictions, _unused_hidden = model(inputs)
    return predictions, targets

# Plain engine around the custom evaluation step; its output is the
# (y_pred, y) pair returned by evaluation_step.
evaluator = Engine(evaluation_step)

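The remaining code can stay the same, with one addition: since `evaluator` is now a plain `Engine` rather than the result of `create_supervised_evaluator`, the metrics have to be attached to it explicitly, e.g. `Loss(loss_func).attach(evaluator, "loss")`.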

Hope this helps and, please, let me know if it works or not.

Thanks for this code! I will let you know when I am able to test this out and if it works, but the step with the hidden state definitely looks like it will help.