I have the following custom GRU class, which unrolls the recurrence with a Python for loop. How can I use JIT to speed it up?
class GRU(tc.nn.Module):
    """GRU with zoneout support; input is batch-first.

    Unrolls the recurrence one timestep at a time with a single
    ``tc.nn.GRUCell``, so a zoneout mask can be applied between steps.

    Args:
        input_dim: size of each input feature vector.
        hidden_size: size of the GRU hidden state.
        zoneout_p: probability of *keeping* the previous hidden value for
            each unit (0 disables zoneout entirely).
    """

    def __init__(self, input_dim, hidden_size, zoneout_p=0):
        super().__init__()
        self.gru_cell = tc.nn.GRUCell(input_dim, hidden_size)
        self.hidden_size = hidden_size
        self.zoneout_p = zoneout_p

    def forward(self, x):
        """Run the GRU over a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tuple of (all hidden states, shape (batch, seq_len, hidden_size);
            final hidden state, shape (1, batch, hidden_size)).
        """
        assert len(x.shape) == 3, x.shape
        # First step: no previous hidden state (GRUCell uses zeros), and no
        # zoneout is applied here — mixing with the all-zero initial state
        # would just shrink the first output.
        hidden_states = [self.gru_cell(x[:, 0, :])]
        for i in range(1, x.shape[1]):
            h_prev = hidden_states[-1]
            h_new = self.gru_cell(x[:, i, :], h_prev)  # (batch, hidden)
            # Skip all mask work when zoneout is disabled (the original
            # sampled an all-False mask on every training step).
            if self.zoneout_p > 0:
                if self.training:
                    # Per-unit Bernoulli mask: keep the previous hidden
                    # value with probability zoneout_p.
                    s = tc.rand_like(h_new) < self.zoneout_p
                    h_new = s * h_prev + (~s) * h_new
                else:
                    # Eval-time zoneout uses the expectation of the mask
                    # (Krueger et al., "Zoneout"), analogous to how dropout
                    # rescales at inference. The original skipped this,
                    # making train/eval dynamics inconsistent.
                    h_new = self.zoneout_p * h_prev + (1.0 - self.zoneout_p) * h_new
            hidden_states.append(h_new)
        # stack(dim=1) == cat of [:, None, :]-expanded slices, but clearer.
        return tc.stack(hidden_states, dim=1), hidden_states[-1][None]
I am running the code on GPU, my typical input size is (10000, 50, 7)
I have tried subclassing it from tc.jit.ScriptModule and decorating the forward method with tc.jit.script_method, following the PyTorch blog post "Optimizing CUDA Recurrent Neural Networks with TorchScript".
But it actually made the code even slower.