Hi James,
Thanks for your explanation. Here is more of my code (sorry — because of copyright issues, I can't post the entire code here). From the printed information, I can see that the LSTM does enter both modes (`flag == True` and `flag == False`), but the LSTM still sees the input regardless of the mode.
# NOTE(review): the original paste lost all indentation; the layout below
# restores the obvious class structure without changing any code.
class LSTM_MODEL(nn.Module):
    """Conv encoder + two stacked LSTM cells + deconv decoder.

    Presumably a sequence autoencoder in which `feature` encodes each
    frame to a 512-d vector, the LSTM pair evolves the latent state, and
    `decode` reconstructs a frame — TODO confirm against the redacted
    `feature`/`decode` bodies.
    """

    def __init__(self):
        super(LSTM_MODEL, self).__init__()
        # --- convolutional encoder ---
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # --- fully connected layers (3380 <-> 1024 <-> 512 bottleneck) ---
        self.fc1 = nn.Linear(3380, 1024)
        self.fc12 = nn.Linear(2048,1024)
        self.fc21 = nn.Linear(1024, 512)
        self.fc22 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512,1024)
        self.fc4 = nn.Linear(1024, 3380)
        # --- transposed convolutions mirroring the encoder ---
        self.convtranspose1 = nn.ConvTranspose2d(10, 1, kernel_size = 5)
        self.convtranspose2 = nn.ConvTranspose2d(20, 10, kernel_size = 5)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.unpool = nn.MaxUnpool2d(2)
        # --- two stacked LSTM cells: 512-d feature -> 1000-d -> 512-d ---
        self.lstm1 = nn.LSTMCell(512, 1000)
        self.lstm2 = nn.LSTMCell(1000, 512)
    def feature(self, x):
        # Encoder for a single time step; returns (x_feature, idx1, idx2)
        # per the call site in forward(). Body redacted by the poster
        # for copyright reasons.
        ......
    def decode(self, z, idx1, idx2):
        # Decoder: reconstructs a frame from the latent `z` using the
        # max-pool indices idx1/idx2 (presumably for MaxUnpool2d — TODO
        # confirm). Body redacted by the poster.
        .......
    def forward(self, input, future = 0, train_flag = 1):
        """Run the sequence through the encoder, LSTM pair, and decoder.

        input      -- sequence tensor iterated one step at a time along
                      dim=1; assumes batch size 200 (hard-coded below).
        future     -- extra steps to free-run after the input sequence.
        train_flag -- 1 for training (per-step mode chosen at random),
                      anything else for testing (null feature every step).
        """
        if train_flag == 1:
            print("training")
        else:
            print("testing")
        outputs = []
        # Initial hidden/cell states for both LSTM cells.
        # NOTE(review): batch size 200 is hard-coded; this breaks for any
        # other batch size — consider deriving it from input.size(0).
        h_t = Variable(torch.zeros(200, 1000).float(), requires_grad=False)
        c_t = Variable(torch.zeros(200, 1000).float(), requires_grad=False)
        h_t2 = Variable(torch.zeros(200, 512).float(), requires_grad=False)
        c_t2 = Variable(torch.zeros(200, 512).float(), requires_grad=False)
        # All-zero stand-in fed to lstm1 when the real feature is withheld.
        FEATURE_null = Variable(torch.zeros(200,512).float(),requires_grad=False)
        if args.cuda:
            h_t = h_t.cuda()
            c_t = c_t.cuda()
            h_t2 = h_t2.cuda()
            c_t2 = c_t2.cuda()
            FEATURE_null = FEATURE_null.cuda()
        ###############################
        # LSTM
        ###############################
        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            input_t = input_t.squeeze(1).contiguous()
            # Encode this frame (idx1/idx2 are reused by decode below).
            x_feature, idx1, idx2 = self.feature(input_t)
            # important: arbitrarily choose 0 or 1.
            if train_flag == 1: # training: arbitrarily choosing mode
                flag = random.choice([True, False])
            else: # test: prediction
                flag = False
            if i == 0: #( first time step always gets True Flag)
                flag = True
            ## the following is the lstm part.
            # When flag is True the LSTM sees the real frame feature;
            # when False it sees only zeros and must rely on its state.
            if flag == True:
                print("flag is True")
                h_t, c_t = self.lstm1(x_feature, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            else:
                print("flag is False")
                h_t, c_t = self.lstm1(FEATURE_null, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            # NOTE(review): both branches feed c_t (the cell state) into
            # lstm2 rather than h_t (the hidden output) — confirm this is
            # intentional; h_t is the conventional input to the next cell.
            recon_x = self.decode(c_t2, idx1, idx2)
            outputs += [recon_x]
        # NOTE(review): this `i` shadows the enumerate index above —
        # harmless after the loop ends, but worth renaming for clarity.
        for i in range(future):# if we should predict the future
            h_t, c_t = self.lstm1(FEATURE_null, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            # Free-running steps reuse the LAST frame's idx1/idx2.
            recon_x = self.decode(c_t2, idx1, idx2)
            outputs += [recon_x]
        outputs = torch.stack(outputs, 1).squeeze(1)
        # NOTE(review): mu_list, logvar_list and lstm_hidden are never
        # assigned anywhere in this method — this line raises NameError as
        # shown. They presumably belong to the redacted VAE code; either
        # collect them in the loop or drop them from the return.
        return outputs, mu_list, logvar_list, lstm_hidden
# NOTE(review): VAE is defined elsewhere (not shown in this paste) — the
# LSTM_MODEL class above is not the class instantiated here. `args` also
# comes from an argparse block outside this snippet.
model = VAE()
if args.cuda:
    model.cuda()
def loss_function():
    # Training loss (presumably reconstruction + KL terms for the VAE —
    # TODO confirm). Body redacted by the poster.
    .........