How to compile an RNN loop once?

Hi,

I’d like to use torch.compile’s ability to unroll Python for loops to implement RNNs. When I stack several identical layers of the same RNN, compilation time grows in proportion to the number of layers: the compiled code is not reused across layers, which costs a lot of time and memory.

In practice this means I can’t compile a reasonably large RNN successfully.
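A minimal sketch (not the code from compile_rnns.py) of why compile cost grows with loop length: Dynamo runs the Python for loop at trace time, so every iteration adds its own copy of the ops to the captured graph. The names counting_backend and rnn_loop are hypothetical; a callable taking (gm, example_inputs) and returning a callable is the standard torch.compile custom-backend hook, used here only to measure graph size.

```python
import torch
import torch._dynamo

# Hypothetical counting backend: records how many FX nodes Dynamo captured
# for each compiled graph, then runs the graph unchanged.
graph_sizes = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    graph_sizes.append(len(gm.graph.nodes))
    return gm.forward

def rnn_loop(x, h, w):
    # x: [B, T, H]. The Python loop over T executes during tracing, so each
    # time step becomes a separate copy of these ops in the captured graph.
    for t in range(x.shape[1]):
        h = torch.tanh(x[:, t] + h @ w)
    return h

torch._dynamo.reset()
# dynamic=False: a new T fails the shape guard and produces a fresh graph.
compiled = torch.compile(rnn_loop, backend=counting_backend, dynamic=False)

B, H = 2, 8
w = torch.randn(H, H)
h0 = torch.zeros(B, H)
for T in (4, 8):
    x = torch.randn(B, T, H)
    out = compiled(x, h0, w)
    assert torch.allclose(out, rnn_loop(x, h0, w))

print(graph_sizes)
```

The node count grows roughly linearly with T. Unrolling several layers inside one compiled function multiplies it again by the layer count, which is consistent with the growth described above.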

Example model with compilation metrics: compile_rnns.py · GitHub

I tried looping through the layers in eager mode and calling only the compiled forward function, but the function was still specialized for each layer: Revisions · compile_rnns.py · GitHub

Running with TORCH_LOGS=guards shows the following guards. ___check_obj_id looks suspicious: does it mean the code is specialized to a particular tensor?

GUARDS:
___check_type_id(L['T'], 94350706369312)                                      
L['T'] == 1024                                                                
hasattr(L['x'], '_dynamo_dynamic_indices') == False           # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls 
___check_obj_id(L['input'], 140194674075456)                
L['input'].training == True                                                   
___check_obj_id(L['forget'], 140194674076224)               
L['forget'].training == True                                                  
hasattr(L['hidden'], '_dynamo_dynamic_indices') == False      # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls 
___check_obj_id(L['output'], 140194674077184)               
L['output'].training == True                                                  
utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:374 in init_ambient_guards
___skip_backend_check() or ___current_backend() == ___lookup_backend(140196638382528)  # _dynamo/output_graph.py:380 in init_ambient_guards
check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=True, size=[8, 1024, 64], stride=[65536, 64, 1])  # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls 
check_tensor(L['hidden'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[8, 64], stride=[64, 1])  # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls 

Could you advise on the best tactic to successfully compile long loops?

Thanks

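It turns out my hunch about ___check_obj_id was right. My code was passing nn.Module objects as arguments, and Dynamo guarded on the identity of each module object, so every layer got its own specialization. I changed the arguments to nn.Parameter tensors, which made Dynamo emit check_tensor guards for them (on type, dtype, device, shape, stride, and requires_grad) instead of object-identity guards.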

The result is available here: compile_rnns3.py · GitHub
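A minimal sketch of the pattern described above, assuming each layer can be written as a plain function of tensors. This is a reconstruction for illustration, not the code in compile_rnns3.py: the hypothetical rnn_layer receives the layer's weights as nn.Parameter arguments, the loop over layers stays in eager mode, and a hypothetical counting backend checks that Dynamo compiles only once for all four layers.

```python
import torch
import torch._dynamo
import torch.nn as nn

compile_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical helper: count how many graphs Dynamo hands to the backend.
    global compile_count
    compile_count += 1
    return gm.forward

def rnn_layer(x, w_in, w_h):
    # Only tensors are passed in, so Dynamo guards on tensor metadata
    # (dtype, device, shape, stride, requires_grad) rather than on the
    # identity of an nn.Module object.
    h = torch.zeros(x.shape[0], w_h.shape[0], dtype=x.dtype, device=x.device)
    outs = []
    for t in range(x.shape[1]):  # unrolled at trace time
        h = torch.tanh(x[:, t] @ w_in + h @ w_h)
        outs.append(h)
    return torch.stack(outs, dim=1)  # [B, T, H], contiguous

torch._dynamo.reset()
compiled_layer = torch.compile(rnn_layer, backend=counting_backend)

B, T, H, num_layers = 2, 16, 8, 4
layers = [
    (nn.Parameter(torch.randn(H, H) * 0.1), nn.Parameter(torch.randn(H, H) * 0.1))
    for _ in range(num_layers)
]

# requires_grad=True keeps the first input consistent with later layers'
# inputs, which are outputs of earlier layers and therefore require grad;
# otherwise the requires_grad check would fail once and force a recompile.
x = torch.randn(B, T, H, requires_grad=True)
y = x
for w_in, w_h in layers:  # eager loop over layers, one compiled function
    y = compiled_layer(y, w_in, w_h)

y.sum().backward()
print(compile_count)
```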

The compiled RNN is about 10x faster than eager mode at sequence length 128 and about 8x faster at sequence length 256. On longer sequences, compilation crashed with a stack overflow somewhere in torch.fx.
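For the crash on longer sequences, one workaround not tried in this thread (an assumption, not verified against that torch.fx stack overflow) is chunked unrolling. The idea is to compile a function that unrolls only a fixed number of time steps and to carry the hidden state across chunks in an eager loop. This bounds the graph size by the chunk length instead of the full sequence length. The names rnn_chunk and run_chunked are hypothetical.

```python
import torch
import torch._dynamo

compile_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical helper: count how many graphs Dynamo compiles.
    global compile_count
    compile_count += 1
    return gm.forward

def rnn_chunk(x_chunk, h, w_in, w_h):
    # Unrolls only x_chunk.shape[1] steps, so the graph size depends on the
    # chunk length, not on the full sequence length.
    outs = []
    for t in range(x_chunk.shape[1]):
        h = torch.tanh(x_chunk[:, t] @ w_in + h @ w_h)
        outs.append(h)
    return torch.stack(outs, dim=1), h

def run_chunked(step_fn, x, h, w_in, w_h, chunk):
    # Eager loop over chunks, carrying the hidden state between calls.
    # Assumes T is divisible by chunk so every call sees the same shapes.
    outs = []
    for start in range(0, x.shape[1], chunk):
        # .contiguous() gives every chunk identical sizes and strides.
        y, h = step_fn(x[:, start:start + chunk].contiguous(), h, w_in, w_h)
        outs.append(y)
    return torch.cat(outs, dim=1)

torch._dynamo.reset()
compiled_chunk = torch.compile(rnn_chunk, backend=counting_backend)

B, T, H, chunk = 2, 32, 8, 8
x = torch.randn(B, T, H)
w_in = torch.randn(H, H) * 0.1
w_h = torch.randn(H, H) * 0.1
h0 = torch.zeros(B, H)

y_chunked = run_chunked(compiled_chunk, x, h0, w_in, w_h, chunk)
y_reference, _ = rnn_chunk(x, h0, w_in, w_h)  # full-length eager unroll
```

A shorter tail chunk would fail the shape guard and trigger a recompile. In training, where the weights require grad, the initial hidden state usually does not require grad while later ones do, so expect one extra compile for the first chunk.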

Btw love torch.export.