IF-ELSE statement in LSTM

Hi, I am wondering if it is possible to include an IF-ELSE statement in the LSTM part of the code. Here is the section:

        if flag:
            print("reconstruction")
            h_t, c_t = self.lstm1(z, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
        else:
            print("prediction")
            h_t, c_t = self.lstm1(z_null, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))

where z_null is an all-zero vector with the same shape as z.

What I want is for the LSTM, at each time step, to either receive an actual input or rely only on the information from the previous hidden state.

Since PyTorch builds its graph dynamically, I assume this should be possible. But in my experiments, it seems the LSTM actually gets the input at every time step, regardless of the IF-ELSE statement.

Could someone help me with this question? Thanks!

You probably want to use nn.LSTMCell or explicitly pass one timestep of data at a time to nn.LSTM.
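
Roughly like this, stepping the cell manually and deciding per step what to feed it (the sizes here are arbitrary, just to illustrate the pattern):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    cell = nn.LSTMCell(8, 16)                 # input_size=8, hidden_size=16

    h = Variable(torch.zeros(4, 16))          # (batch, hidden)
    c = Variable(torch.zeros(4, 16))
    seq = Variable(torch.randn(4, 10, 8))     # (batch, time, input_size)
    zero_input = Variable(torch.zeros(4, 8))  # stand-in for "no input"

    for t in range(seq.size(1)):
        use_input = (t % 2 == 0)              # any Python condition works here
        x_t = seq[:, t, :] if use_input else zero_input
        h, c = cell(x_t, (h, c))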

Hi James,

Thanks for your answer. Yes, I am using LSTMCell. The structure is similar to:
self.lstm = nn.LSTMCell(in_dim, out_dim)

The vanilla LSTM works fine. Can we use an IF-ELSE statement to control whether the LSTMCell gets only the previous hidden state, or both the hidden state and an actual input?

Again, thanks for your help!

How do you generate the flag? Can you print it to check that it is actually changing?

Hi Ruotian,

Sorry, I only posted the core part of the code. The flag is randomly generated as either True or False:

            flag = random.choice([True, False])

And yes, it reaches both the “reconstruction” part and the “prediction” part of the code. Is this type of control flow supported by PyTorch? I noticed an example using a similar idea: https://github.com/jcjohnson/pytorch-examples#pytorch-control-flow--weight-sharing.

Thank you for your help.

Eric

Yes, it’s supported.

Yeah, what you’re trying to do is absolutely supported and should work. If you’re still running into problems, can you paste a little more of your code? Also make sure everything that needs to happen for every iteration (like the random.choice) is in forward and not __init__.
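
For example, a toy module where the choice is re-drawn on every forward call (made-up sizes; the point is only where random.choice lives):

    import random

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    class StochasticStep(nn.Module):
        def __init__(self):
            super(StochasticStep, self).__init__()
            self.cell = nn.LSTMCell(8, 16)
            # NOT here: self.flag = random.choice([True, False])
            # would be evaluated once, when the module is constructed.

        def forward(self, x, state):
            flag = random.choice([True, False])  # re-drawn on every call
            if flag:
                return self.cell(x, state)
            zeros = Variable(torch.zeros(x.size()))  # all-zero stand-in input
            return self.cell(zeros, state)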

Hi James,

Thanks for your explanation. Here is more of my code (sorry, due to copyright issues I can’t post the entire thing). From the printed output, I can see that the code does reach both modes (flag == True and flag == False), but the LSTM seems to see the input regardless of the mode.

import random

import torch
import torch.nn as nn
from torch.autograd import Variable


class LSTM_MODEL(nn.Module):
    def __init__(self):
        super(LSTM_MODEL, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(3380, 1024)
        self.fc12 = nn.Linear(2048, 1024)
        self.fc21 = nn.Linear(1024, 512)
        self.fc22 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 3380)
        self.convtranspose1 = nn.ConvTranspose2d(10, 1, kernel_size=5)
        self.convtranspose2 = nn.ConvTranspose2d(20, 10, kernel_size=5)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.unpool = nn.MaxUnpool2d(2)
        self.lstm1 = nn.LSTMCell(512, 1000)
        self.lstm2 = nn.LSTMCell(1000, 512)

    def feature(self, x):
        ......  # elided

    def decode(self, z, idx1, idx2):
        ......  # elided

    def forward(self, input, future=0, train_flag=1):

        if train_flag == 1:
            print("training")
        else:
            print("testing")

        outputs = []

        # initial hidden/cell states (batch size 200)
        h_t = Variable(torch.zeros(200, 1000).float(), requires_grad=False)
        c_t = Variable(torch.zeros(200, 1000).float(), requires_grad=False)
        h_t2 = Variable(torch.zeros(200, 512).float(), requires_grad=False)
        c_t2 = Variable(torch.zeros(200, 512).float(), requires_grad=False)
        # all-zero stand-in, fed to the LSTM when it should get no real input
        FEATURE_null = Variable(torch.zeros(200, 512).float(), requires_grad=False)

        if args.cuda:
            h_t = h_t.cuda()
            c_t = c_t.cuda()
            h_t2 = h_t2.cuda()
            c_t2 = c_t2.cuda()
            FEATURE_null = FEATURE_null.cuda()

        ###############################
        # LSTM
        ###############################
        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):

            input_t = input_t.squeeze(1).contiguous()
            x_feature, idx1, idx2 = self.feature(input_t)

            if train_flag == 1:  # training: randomly choose the mode
                flag = random.choice([True, False])
            else:  # testing: always predict
                flag = False

            if i == 0:  # the first time step always gets the real input
                flag = True

            # the LSTM part
            if flag:
                print("flag is True")
                h_t, c_t = self.lstm1(x_feature, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            else:
                print("flag is False")
                h_t, c_t = self.lstm1(FEATURE_null, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))

            recon_x = self.decode(c_t2, idx1, idx2)
            outputs += [recon_x]

        for i in range(future):  # optionally keep predicting into the future
            h_t, c_t = self.lstm1(FEATURE_null, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            recon_x = self.decode(c_t2, idx1, idx2)
            outputs += [recon_x]

        outputs = torch.stack(outputs, 1).squeeze(1)
        # mu_list, logvar_list, lstm_hidden are produced by the elided code
        return outputs, mu_list, logvar_list, lstm_hidden

model = LSTM_MODEL()
if args.cuda:
    model.cuda()

def loss_function():
    ......  # elided
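
For reference, the control-flow part of my forward boils down to something like this self-contained snippet (random data in place of my real features, same sizes as above):

    import random

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    lstm1 = nn.LSTMCell(512, 1000)
    lstm2 = nn.LSTMCell(1000, 512)

    h_t = Variable(torch.zeros(200, 1000))
    c_t = Variable(torch.zeros(200, 1000))
    h_t2 = Variable(torch.zeros(200, 512))
    c_t2 = Variable(torch.zeros(200, 512))
    FEATURE_null = Variable(torch.zeros(200, 512))  # all-zero "no input"

    for i in range(5):  # five time steps of fake features
        x_feature = Variable(torch.randn(200, 512))
        flag = True if i == 0 else random.choice([True, False])
        print("step %d, flag=%s" % (i, flag))
        x_in = x_feature if flag else FEATURE_null
        h_t, c_t = lstm1(x_in, (h_t, c_t))
        h_t2, c_t2 = lstm2(c_t, (h_t2, c_t2))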