I’m trying to build a GRU network from the simpler building blocks in torch.nn. The nn.GRU version runs really fast, but my own version is painfully slow, even though it should also run on the GPU. I’m trying to understand why that is.
Here’s the nn.GRU version:
def __init__(self, in_dim, h_dim, out_dim, n_layers, dropout=0):
    """Build a stacked GRU followed by a linear decoder.

    Args:
        in_dim: Feature size of each input step.
        h_dim: Hidden size of every GRU layer.
        out_dim: Feature size produced by the decoder.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability handed to ``nn.GRU``.
    """
    # NOTE(review): if the enclosing class is an nn.Module, its base-class
    # __init__ must have run before these assignments; that call is not
    # visible in this excerpt -- confirm.
    # nn.GRU applies dropout only between stacked layers, so with a single
    # layer pass 0 to avoid the warning it would otherwise emit.
    self.rnn = nn.GRU(
        in_dim,
        h_dim,
        n_layers,
        dropout=dropout if n_layers > 1 else 0,
    )
    self.decoder = nn.Linear(h_dim, out_dim)
def forward_(self, input: Tensor, hidden_state: Tensor = None):
    """Run the GRU over ``input`` and decode every output step.

    Args:
        input: Sequence tensor passed unchanged to ``self.rnn``.
            Presumably laid out as (seq_len, batch, in_dim), since the GRU
            in this excerpt is built without ``batch_first`` -- confirm.
        hidden_state: Optional initial hidden state; ``None`` is forwarded
            unchanged to ``self.rnn``.

    Returns:
        A tuple ``(layer_output, hidden)``: the decoder applied to the GRU
        output, and the hidden state returned by the GRU.
    """
    output, hidden = self.rnn(input, hidden_state)
    layer_output = self.decoder(output)
    return layer_output, hidden
And here’s my “manual” version:
def __init__(self, in_dim, h_dim, out_dim, n_layers, dropout=0):
    """Create the linear maps of a hand-written multi-layer GRU.

    For layer ``i`` six ``nn.Linear`` modules are created under the keys
    ``L{i}W1`` .. ``L{i}W6``. ``forward`` uses W1/W2 for the ``z`` gate,
    W3/W4 for the ``r`` gate and W5/W6 for the candidate ``g``; the
    odd-numbered maps act on the layer input (with bias), the even-numbered
    ones on the recurrent state (without bias). ``W7`` is the output
    decoder.

    Args:
        in_dim: Feature size of the input to the first layer.
        h_dim: Hidden size of every layer (and input size of layers > 0).
        out_dim: Feature size produced by the decoder ``W7``.
        n_layers: Number of stacked recurrent layers.
        dropout: Dropout probability, stored for use in ``forward``.
    """
    # NOTE(review): add_module() below needs the base-class __init__ to
    # have run already; that call is not visible in this excerpt -- confirm.
    # forward() reads these attributes, so they have to be stored here.
    self.h_dim = h_dim
    self.out_dim = out_dim
    self.n_layers = n_layers
    self.dropout = dropout

    # The dict must exist before the loop fills it (the original excerpt
    # only had a commented-out initialisation, giving a NameError).
    params = {}
    layer_in = in_dim
    for i in range(n_layers):
        params[f'L{i}W1'] = nn.Linear(layer_in, h_dim)
        params[f'L{i}W2'] = nn.Linear(h_dim, h_dim, bias=False)
        params[f'L{i}W3'] = nn.Linear(layer_in, h_dim)
        params[f'L{i}W4'] = nn.Linear(h_dim, h_dim, bias=False)
        params[f'L{i}W5'] = nn.Linear(layer_in, h_dim)
        params[f'L{i}W6'] = nn.Linear(h_dim, h_dim, bias=False)
        # Every layer after the first consumes the previous layer's state.
        layer_in = h_dim
    params['W7'] = nn.Linear(h_dim, out_dim)

    # Register each map under its key so it is tracked as a submodule,
    # and keep the plain dict for lookup by key in forward().
    for name, module in params.items():
        self.add_module(name, module)
    self.params = params
# Hand-written GRU forward pass: for every layer, walk the sequence one time
# step at a time, update that layer's state, then decode each step with W7.
# NOTE(review): this excerpt never defines batch_size, seq_len, layer_states
# or the initial layer_middle, and it ends without a return statement -- the
# missing lines must be restored before this can run.
def forward(self, input: Tensor, hidden_state: Tensor=None):
# With no hidden state supplied, a zero state of shape (batch_size, h_dim) is
# appended; presumably this happens once per layer -- confirm.
if hidden_state is None:
layer_states.append(torch.zeros(batch_size, self.h_dim, device=input.device))
for i in range(self.n_layers):
# The previous layer's outputs become this layer's inputs; a fresh buffer
# collects this layer's per-step states.
layer_input = layer_middle
layer_middle = torch.zeros(batch_size, seq_len, self.h_dim, device=input.device)
params = self.params
# NOTE(review): this Python-level loop runs several small Linear / sigmoid /
# tanh ops per time step and per layer, whereas the nn.GRU version handles the
# whole sequence in a single call. That is the likely cause of the slowdown.
# W1, W3 and W5 depend only on x, not on the recurrent state, so they could be
# applied to all of layer_input once, before this loop.
for j in range(seq_len):
# Index 1 is treated as the time axis here.
x = layer_input[:, j, :]
# z and r gates: input projection plus bias-free projection of the state.
z = F.sigmoid(params[f'L{i}W1'](x) + params[f'L{i}W2'](layer_states[i]))
r = F.sigmoid(params[f'L{i}W3'](x) + params[f'L{i}W4'](layer_states[i]))
# Candidate g: r scales the projected state before the tanh.
g = F.tanh(params[f'L{i}W5'](x) + r*params[f'L{i}W6'](layer_states[i]))
# New state blends the old state (weight z) with the candidate (weight 1-z).
layer_states[i] = z*layer_states[i] + (1-z)*g
# In-place write of this step's state into the preallocated buffer.
layer_middle[:, j, :] = layer_states[i]
# NOTE(review): a new nn.Dropout is built on every call. A freshly built
# module starts in training mode, so this presumably still drops activations
# after self.eval() -- confirm. Consider F.dropout(..., training=self.training).
layer_middle = nn.Dropout(self.dropout)(layer_middle)
layer_output = torch.zeros((batch_size, seq_len, self.out_dim), device=input.device)
# NOTE(review): W7 is applied one time step at a time. nn.Linear acts on the
# last dimension, so a single params['W7'](layer_middle) call should give the
# same values without the loop -- confirm.
for j in range(seq_len):
x = layer_middle[:, j, :]
layer_output[:, j, :] = params['W7'](x)