I am using a 2 layer stacked lstm with a width of 128 in PyTorch. Since I want to use more layers and change the width of the layers, I have written the lstm class I am using in a generalized form. When I use the general model with width=128
, depth=2
I get the exact same results in each epoch (with set random state). The only difference is the performance of both models. While the hardcoded 2 layer lstm takes ~60s per epoch, the general model takes ~90s. How is this possible?
Hard coded model (60s per epoch):
class Network(nn.Module):
def __init__(
self,
input_size,
width,
depth,
device
):
super(Network, self).__init__()
self.input_size = input_size
self.device = device
self.lstm_1 = nn.LSTMCell(self.input_size, 128)
self.lstm_2 = nn.LSTMCell(128, 128)
self.linear_1 = nn.Linear(128, 32)
self.linear_2 = nn.Linear(32, 2)
self.dropout_1 = nn.Dropout(0.2)
def forward(self, data):
h_t_1 = torch.zeros(data.size(0), 128).to(self.device)
c_t_1 = torch.zeros(data.size(0), 128).to(self.device)
h_t_2 = torch.zeros(data.size(0), 128).to(self.device)
c_t_2 = torch.zeros(data.size(0), 128).to(self.device)
for time_step in data.split(1, dim=1):
h_t_1, c_t_1 = self.lstm_1(time_step.view(data.size(0), self.input_size), (h_t_1, c_t_1))
h_t_2, c_t_2 = self.lstm_2(self.dropout_1(h_t_1), (h_t_2, c_t_2))
output = self.dropout_1(self.linear_1(self.dropout_1(h_t_2)))
output = self.linear_2(output)
mean = output[..., 0][..., None]
std = torch.clamp(output[..., 1][..., None], min=0.01)
norm_dist = torch.distributions.Normal(mean, std)
return norm_dist
General model (90s per epoch):
class Network(nn.Module):
def __init__(
self,
input_size,
width,
depth,
device
):
super(Network, self).__init__()
self.input_size = input_size
self.width = width
self.depth = depth
self.device = device
self.lstm_1 = nn.LSTMCell(self.input_size, self.width)
for i in range(self.depth - 1):
setattr(self, f'lstm_{i+2}', nn.LSTMCell(self.width, self.width))
self.linear_1 = nn.Linear(self.width, 32)
self.linear_2 = nn.Linear(32, 2)
self.dropout_1 = nn.Dropout(0.2)
def forward(self, data):
h_t_1 = torch.zeros(data.size(0), self.width).to(self.device)
c_t_1 = torch.zeros(data.size(0), self.width).to(self.device)
for i in range(self.width - 1):
locals()[f'h_t_{i+2}'] = torch.zeros(data.size(0), self.width).to(self.device)
locals()[f'c_t_{i+2}'] = torch.zeros(data.size(0), self.width).to(self.device)
for time_step in data.split(1, dim=1):
h_t_1, c_t_1 = self.lstm_1(time_step.view(data.size(0), self.input_size), (h_t_1, c_t_1))
for i in range(self.depth - 1):
locals()[f'h_t_{i+2}'], locals()[f'c_t_{i+2}'] = getattr(self, f'lstm_{i+2}')(self.dropout_1(locals()[f'h_t_{i+1}']), (locals()[f'h_t_{i+2}'], locals()[f'c_t_{i+2}']))
output = self.dropout_1(self.linear_1(self.dropout_1(locals()[f'h_t_{self.depth}'])))
output = self.linear_2(output)
mean = output[..., 0][..., None]
std = torch.clamp(output[..., 1][..., None], min=0.01)
norm_dist = torch.distributions.Normal(mean, std)
return norm_dist