Hey there!
I am trying to implement a model where each ‘pixel’ of the input runs through its own GRU cell (instead of flattening the image and feeding all pixels to a single GRU as features).
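For context, the flattened alternative I am avoiding would look roughly like this (a minimal sketch with the same sizes as my test below — 14 * 14 pixels, 3 channels, sequence length 100; the variable names are just for illustration):

```python
import torch
import torch.nn as nn

# Flattened alternative: one GRU sees all pixels as one feature vector per timestep.
flat_gru = nn.GRU(input_size=3 * 14 * 14, hidden_size=10, batch_first=True)
frames = torch.rand(1, 100, 3 * 14 * 14)  # (batch, seq, channels * pixels)
out, _ = flat_gru(frames)                 # out: (batch, seq, hidden_size)
print(tuple(out.shape))                   # (1, 100, 10)
```

This mixes all pixels together in one hidden state, which is exactly what I do not want.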
This forces me to loop over a ModuleList of GRUs to forward-pass each pixel.
I am trying to speed it up using TorchScript, but without success. :(
I want to use this custom module during training.
This is my module:
class ParallelGruList(nn.Module):
    def __init__(self, input_size, hidden_size, num_grus, device):
        super().__init__()
        self.num_grus = num_grus
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device  # passing the device for TorchScript
        self.grus = nn.ModuleList(self.get_gru() for _ in range(self.num_grus)).to(device)

    def get_gru(self):
        return nn.GRU(input_size=self.input_size,
                      hidden_size=self.hidden_size,
                      num_layers=1,
                      batch_first=True,
                      bias=False)

    def forward(self, x):
        batch_size = x.shape[0]
        sequence_length = x.shape[1]
        output = torch.zeros((batch_size, sequence_length, 10, self.num_grus), dtype=torch.float32, device=self.device)
        for i, gru in enumerate(self.grus):
            output[:, :, :, i], _ = gru(x[:, :, :, i], None)
        return output
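For reference, each slice `x[:, :, :, i]` that the loop feeds to a GRU has the standard `batch_first` shape `(batch, seq, input_size)`; a minimal self-contained sketch of one such call (sizes taken from my test below):

```python
import torch
import torch.nn as nn

# One pixel's GRU, matching input_size=3, hidden_size=10 from the module above.
gru = nn.GRU(input_size=3, hidden_size=10, num_layers=1, batch_first=True, bias=False)
x_pixel = torch.rand(1, 100, 3)  # one x[:, :, :, i] slice: (batch, seq, input_size)
out, h_n = gru(x_pixel)          # out: (batch, seq, hidden_size)
print(tuple(out.shape))          # (1, 100, 10)
```

So the loop runs 14 * 14 of these small sequential calls per forward pass, which is where all the time goes.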
And this is how I test it:
def jit_gru(state, grus):
    start = time.time()
    r = grus(state)
    return time.time() - start
if __name__ == '__main__':
    print('\n')
    data = torch.rand(1, 100, 3, 14 * 14, dtype=torch.float32, device=device)
    print('init gru')
    grus = ParallelGruList(input_size=3, hidden_size=10, num_grus=14 * 14, device=device).to(device)
    print('run test')
    h = []
    for i in range(10):
        h.append(jit_gru(data, grus))
    print('exe time smart jit', np.mean(h))
Running it like this will generate the following output:
init gru
run test
exe time smart jit 0.2266002893447876
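(A side note on the measurement itself: on GPU, wall-clock timings like this only mean much if you call `torch.cuda.synchronize()` around the forward pass, since CUDA ops run asynchronously. A self-contained version of the same timing pattern, assuming CPU so no sync is needed — `time_forward` is just an illustrative name:)

```python
import time
import torch
import torch.nn as nn

def time_forward(module, x, n=10):
    # Average wall-clock time of n forward passes, same pattern as jit_gru above.
    times = []
    with torch.no_grad():
        for _ in range(n):
            start = time.time()
            module(x)
            times.append(time.time() - start)
    return sum(times) / len(times)

gru = nn.GRU(input_size=3, hidden_size=10, batch_first=True)
x = torch.rand(1, 100, 3)
print(f'avg forward time: {time_forward(gru, x):.6f}s')
```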
To speed things up I tried to initialize the module by:
grus = torch.jit.script(ParallelGruList(input_size=3, hidden_size=10, num_grus=14 * 14, device=device).to(device))
But this makes it a little slower:
exe time smart jit 0.296830940246582
I also tried to make the ModuleList a constant, but this leads to the warning:
UserWarning: 'grus' was found in ScriptModule constants, but it is a non-constant submodule. Consider removing it.
warnings.warn("'{}' was found in ScriptModule constants, "
I also tried to decorate the get_gru() method with @torch.jit.export, but that caused the error:
Class GRU does not have an __init__ function defined:
I also tried a couple of other things that usually made the whole thing even slower. I’ve read the JIT reference guide and the documentation of torch.jit.script, but I could not find anything that helped.
At this point I am thoroughly clueless.
Could someone point me in the right direction or let me know what I am missing?
Best,
Constantin