Hi @richard, thanks for your interest.
I attached my graph below. After debugging, the error seems to happen in the torch.stack
or torch.reshape
operation.
T3_LSTM() is my original code, and it cannot run backward().
In T4_LSTM() I commented out the final stack/reshape of the outputs, and backward() then works; everything upstream of that step is identical. I'm not sure if this is a bug or a misuse on my side.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
BATCH_SIZE = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class T1_LSTM(nn.Module):
    def __init__(self, input_channels, lstm_hidden_size=100, lstm_num_layers=2):
        super(T1_LSTM, self).__init__()
        self.lstm_1 = nn.LSTM(input_channels,
                              lstm_hidden_size,
                              lstm_num_layers,
                              bias=False,
                              bidirectional=True)
        self.lstm_2 = nn.LSTM(input_channels,
                              lstm_hidden_size,
                              lstm_num_layers,
                              bias=False,
                              bidirectional=True)
        self.lstm_1_states = (
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
        )
        self.lstm_2_states = (
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
        )

    def forward(self, first_chain, second_chain):
        lstm_1_out, self.lstm_1_states = self.lstm_1(
            first_chain, self.lstm_1_states)
        lstm_2_out, self.lstm_2_states = self.lstm_2(
            second_chain, self.lstm_1_states)
        return lstm_1_out[-1], self.lstm_1_states, lstm_2_out[-1], self.lstm_2_states
class T2_LSTM(nn.Module):
    def __init__(self, input_channels, lstm_hidden_size=100, lstm_num_layers=2):
        super(T2_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_channels,
                            lstm_hidden_size,
                            lstm_num_layers,
                            bias=False,
                            bidirectional=True)
        self.lstm_states = (
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
            torch.zeros((lstm_num_layers*2, BATCH_SIZE, lstm_hidden_size)).to(device),
        )

    def forward(self, input, t1_states):
        lstm_out, self.lstm_states = self.lstm(input, t1_states)
        return lstm_out[-1]
class T3_LSTM(nn.Module):
    def __init__(self, sequence_input_channels, lstm_hidden_size=100, lstm_num_layers=2):
        super(T3_LSTM, self).__init__()
        self.t1_lstm_1 = T1_LSTM(sequence_input_channels,
                                 lstm_hidden_size, lstm_num_layers)
        self.t1_lstm_2 = T1_LSTM(sequence_input_channels,
                                 lstm_hidden_size, lstm_num_layers)
        self.t2_lstm = T2_LSTM(sequence_input_channels,
                               lstm_hidden_size, lstm_num_layers)

    def forward(self, input1, input2, input3):
        alpha_out_1, alpha_states_1, alpha_out_2, alpha_states_2 = self.t1_lstm_1(
            input2, input3)
        beta_out_1, beta_states_1, beta_out_2, beta_states_2 = self.t1_lstm_2(
            input3, input2)
        sum_states = (
            torch.add(torch.add(alpha_states_1[0], alpha_states_2[0]), torch.add(beta_states_1[0], beta_states_2[0])),
            torch.add(torch.add(alpha_states_1[1], alpha_states_2[1]), torch.add(beta_states_1[1], beta_states_2[1])),
        )
        p_out = self.t2_lstm(input1, sum_states)
        h_out = torch.add(torch.add(alpha_out_1, alpha_out_2), torch.add(beta_out_1, beta_out_2))
        # stack along width
        out = torch.stack([torch.reshape(p_out, (BATCH_SIZE, 1, p_out.shape[1])),
                           torch.reshape(h_out, (BATCH_SIZE, 1, h_out.shape[1]))], dim=1)
        return out
class T4_LSTM(nn.Module):
    def __init__(self, sequence_input_channels, lstm_hidden_size=100, lstm_num_layers=2):
        super(T4_LSTM, self).__init__()
        self.t1_lstm_1 = T1_LSTM(sequence_input_channels,
                                 lstm_hidden_size, lstm_num_layers)
        self.t1_lstm_2 = T1_LSTM(sequence_input_channels,
                                 lstm_hidden_size, lstm_num_layers)
        self.t2_lstm = T2_LSTM(sequence_input_channels,
                               lstm_hidden_size, lstm_num_layers)

    def forward(self, input1, input2, input3):
        alpha_out_1, alpha_states_1, alpha_out_2, alpha_states_2 = self.t1_lstm_1(
            input2, input3)
        beta_out_1, beta_states_1, beta_out_2, beta_states_2 = self.t1_lstm_2(
            input3, input2)
        sum_states = (
            torch.add(torch.add(alpha_states_1[0], alpha_states_2[0]), torch.add(beta_states_1[0], beta_states_2[0])),
            torch.add(torch.add(alpha_states_1[1], alpha_states_2[1]), torch.add(beta_states_1[1], beta_states_2[1])),
        )
        p_out = self.t2_lstm(input1, sum_states)
        h_out = torch.add(torch.add(alpha_out_1, alpha_out_2), torch.add(beta_out_1, beta_out_2))
        # # stack along width
        # out = torch.stack([torch.reshape(p_out, (BATCH_SIZE, 1, p_out.shape[1])),
        #                    torch.reshape(h_out, (BATCH_SIZE, 1, h_out.shape[1]))], dim=1)
        return p_out, h_out
def test_t3():
    t = T3_LSTM(31)
    a = t.forward(torch.randn(15, 1, 31),
                  torch.randn(115, 1, 31),
                  torch.randn(125, 1, 31))
    print(a.shape)
    a.backward(torch.randn(a.shape))

def test_t4():
    t = T4_LSTM(31)
    a, b = t.forward(torch.randn(15, 1, 31),
                     torch.randn(115, 1, 31),
                     torch.randn(125, 1, 31))
    print(a.shape)
    a.backward(torch.randn(a.shape), retain_graph=True)
    b.backward(torch.randn(b.shape))
# error happens
test_t3()
# works
test_t4()
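For reference, here is just the final stack/reshape step from T3_LSTM in isolation, with stand-in leaf tensors of the shapes the repro produces (hidden_size=100 with bidirectional=True gives 200 output features, so p_out and h_out are (1, 200)). This is only a toy sketch to show the shapes involved, not the failing case itself:

import torch

BATCH_SIZE = 1
# stand-ins for p_out / h_out from the repro, shape (BATCH_SIZE, 200)
p_out = torch.randn(BATCH_SIZE, 200, requires_grad=True)
h_out = torch.randn(BATCH_SIZE, 200, requires_grad=True)

# reshape each to (1, 1, 200), then stack along dim=1 -> (1, 2, 1, 200)
out = torch.stack([torch.reshape(p_out, (BATCH_SIZE, 1, p_out.shape[1])),
                   torch.reshape(h_out, (BATCH_SIZE, 1, h_out.shape[1]))], dim=1)
print(out.shape)                       # torch.Size([1, 2, 1, 200])
out.backward(torch.randn(out.shape))   # fine on plain leaf tensors

Since stack and reshape are differentiable and this toy case backpropagates, I presume whatever goes wrong in test_t3 involves how these outputs connect to the LSTM graph rather than the two ops on their own, but I may be reading it wrong.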