Hello. I have built a network that uses multiple LSTMs as its input layer. Each LSTM handles one particular signal, so there are 10 LSTMs in total, one per signal id. During training, a signal arrives together with its id, and I want to update only the weights that belong to that signal id. Unfortunately, I'm facing the following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.
The error goes away if I pass `retain_graph=True` to `backward()`, but then training slows down dramatically with every forward call in the training loop. How do I change my training loop so that I don't have to use the `retain_graph` parameter? Here's my training loop code:
```python
for epoch in range(self.epochs):
    # States are reset once per epoch and then carried across forward calls
    self.main_model.hidden_states = self.main_model.init_hidden()
    self.main_model.cell_states = self.main_model.init_cells()
    batch = self.generator.get_sample_batch()
    for i, batch_el in enumerate(batch):
        cum_loss = 0
        for j, input_signal in enumerate(batch_el):
            # Freeze everything, then unfreeze only the layers for this signal id
            self.freeze_layers()
            signal_id = int(input_signal[0] - 1)  # first element holds the 1-based id
            self.unfreeze_layers(signal_id)
            payload_size = self.signals_per_id[signal_id]
            # The payload follows the id in the same record
            input_s = torch.as_tensor(
                np.array(input_signal[1:1 + payload_size], dtype=np.float32)
                .reshape(1, 1, payload_size)
            ).float()
            reconstruction = self.main_model(input_s, signal_id)
            loss_value = self.loss(reconstruction, input_s) * self._dataset_inv_freqs[signal_id]
            loss_value.backward()  # retain_graph=True
            cum_loss += loss_value.item()
            # Gradients accumulate until every update_interval-th sample
            if (j + 1) % self.update_interval == 0:
                self.optimizer.step()
                self.main_model.zero_grad()
```
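A likely cause, assuming `main_model.forward` feeds the states from the previous call back into the LSTM (its code is not shown): `hidden_states` and `cell_states` are reset only once per epoch, so every new graph is chained to the previous one through the carried-over states. The first `backward()` frees the buffers of that earlier graph, and the next `backward()` tries to walk back through it. `retain_graph=True` hides the error, but then each call walks the whole history since the start of the epoch, which would explain the growing slowdown. A common remedy is truncated backpropagation through time: detach the states after each `backward()`. The minimal sketch below uses a hypothetical `PerSignalLSTM` (one LSTM and one decoder per signal id, with per-id states) and hypothetical helpers `detach_states`, `train` and `lstms_with_grads`; it also shows that a backward pass only fills `.grad` for the LSTM that was actually used.

```python
import torch
import torch.nn as nn


class PerSignalLSTM(nn.Module):
    """Hypothetical stand-in for main_model: one LSTM per signal id, with (h, c)
    carried over between forward calls, like hidden_states/cell_states."""

    def __init__(self, payload_sizes, hidden_size=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstms = nn.ModuleList([nn.LSTM(size, hidden_size) for size in payload_sizes])
        self.decoders = nn.ModuleList([nn.Linear(hidden_size, size) for size in payload_sizes])
        self.reset_states()

    def reset_states(self):
        # One (h, c) pair per signal id, shape (num_layers, batch, hidden_size)
        n = len(self.lstms)
        self.h = [torch.zeros(1, 1, self.hidden_size) for _ in range(n)]
        self.c = [torch.zeros(1, 1, self.hidden_size) for _ in range(n)]

    def detach_states(self):
        # Keep the state values but cut their link to earlier graphs
        self.h = [h.detach() for h in self.h]
        self.c = [c.detach() for c in self.c]

    def forward(self, x, signal_id):
        state = (self.h[signal_id], self.c[signal_id])
        out, (self.h[signal_id], self.c[signal_id]) = self.lstms[signal_id](x, state)
        return self.decoders[signal_id](out)


def train(signal_ids, payload_sizes, detach, update_interval=2):
    """Mimics the posted loop on random data; returns the number of optimizer steps."""
    torch.manual_seed(0)
    model = PerSignalLSTM(payload_sizes)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    model.reset_states()  # once per epoch, like init_hidden()/init_cells()
    updates = 0
    for j, signal_id in enumerate(signal_ids):
        x = torch.randn(1, 1, payload_sizes[signal_id])
        loss = loss_fn(model(x, signal_id), x)
        loss.backward()  # no retain_graph=True
        if detach:
            model.detach_states()  # truncate backprop at the step boundary
        if (j + 1) % update_interval == 0:
            optimizer.step()
            optimizer.zero_grad()
            updates += 1
    return updates


def lstms_with_grads(model, signal_id, payload_size):
    """Runs one step for signal_id and reports which LSTMs received gradients."""
    model.zero_grad(set_to_none=True)
    x = torch.randn(1, 1, payload_size)
    nn.functional.mse_loss(model(x, signal_id), x).backward()
    model.detach_states()
    return [any(p.grad is not None for p in lstm.parameters()) for lstm in model.lstms]


print(train([0, 0, 1, 1], [3, 5], detach=True))  # 2
# train([0, 0, 1, 1], [3, 5], detach=False) raises
# "RuntimeError: Trying to backward through the graph a second time ..."
print(lstms_with_grads(PerSignalLSTM([3, 5, 4]), 1, 5))  # [False, True, False]
```

Detaching keeps the state values, so each LSTM still carries context from earlier signals; only the gradient path is cut at the step boundary. In the posted loop, the corresponding change would be to detach `self.main_model.hidden_states` and `self.main_model.cell_states` right after `loss_value.backward()`, in whatever form `main_model` stores them. Since `torch.optim` optimizers skip parameters whose `.grad` is `None`, and `zero_grad()` sets gradients to `None` by default in recent PyTorch versions, modules used by only one signal id should not be changed by `optimizer.step()` unless that id appeared since the last update.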