Hi,
I’d like to use the feature of torch.compile that unrolls Python for loops to implement RNNs. When stacking multiple identical layers of the same RNN, I’ve noticed that compilation time grows in proportion to the number of layers: the compiled code is not reused across layers, which costs a lot of time and memory.
In practice this means I can’t compile a reasonably large RNN successfully.
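For concreteness, here is a minimal sketch of the pattern (hypothetical and simplified; the real model is in the gist, and the gate names here are guesses that happen to match the guards quoted below):

```python
import torch
import torch.nn as nn

class Cell(nn.Module):
    # hypothetical stand-in for the cell in the gist
    def __init__(self, dim):
        super().__init__()
        self.input = nn.Linear(dim, dim)
        self.forget = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, dim)

    def forward(self, x, hidden, T):
        outs = []
        # torch.compile traces and unrolls this Python loop, so the graph
        # grows with T, and then again with every stacked layer
        for t in range(T):
            f = torch.sigmoid(self.forget(hidden))
            hidden = f * hidden + (1 - f) * torch.tanh(self.input(x[:, t]))
            outs.append(torch.sigmoid(self.output(hidden)))
        return torch.stack(outs, dim=1), hidden

layers = nn.ModuleList(Cell(64) for _ in range(4))

@torch.compile
def forward(x, hidden, T):
    for layer in layers:  # each layer is traced into one ever-growing graph
        x, hidden = layer(x, hidden, T)
    return x, hidden

x = torch.randn(8, 1024, 64, requires_grad=True)
hidden = torch.zeros(8, 64)
out, hidden = forward(x, hidden, 1024)
```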
Example model with compilation metrics: compile_rnns.py · GitHub
I tried looping through the layers in eager mode and compiling only the cell’s forward function, but the function was specialized for each layer anyway: Revisions · compile_rnns.py · GitHub
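Roughly what I tried, continuing the sketch above (simplified, not the exact gist code):

```python
# compile the unbound forward once and drive the layer loop eagerly
compiled_step = torch.compile(Cell.forward)

def forward(x, hidden, T):
    for layer in layers:
        # Dynamo still guards on the identity of `layer` (___check_obj_id),
        # so every distinct layer instance gets its own compiled copy
        x, hidden = compiled_step(layer, x, hidden, T)
    return x, hidden
```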
TORCH_LOGS=guards shows the guards below. ___check_obj_id looks suspicious: does it mean the compiled function is specialized to the particular objects (tensors/modules) it was traced with?
GUARDS:

```
___check_type_id(L['T'], 94350706369312)
L['T'] == 1024
hasattr(L['x'], '_dynamo_dynamic_indices') == False # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls
___check_obj_id(L['input'], 140194674075456)
L['input'].training == True
___check_obj_id(L['forget'], 140194674076224)
L['forget'].training == True
hasattr(L['hidden'], '_dynamo_dynamic_indices') == False # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls
___check_obj_id(L['output'], 140194674077184)
L['output'].training == True
utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:374 in init_ambient_guards
___skip_backend_check() or ___current_backend() == ___lookup_backend(140196638382528) # _dynamo/output_graph.py:380 in init_ambient_guards
check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=True, size=[8, 1024, 64], stride=[65536, 64, 1]) # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls
check_tensor(L['hidden'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[8, 64], stride=[64, 1]) # _dynamo/variables/builder.py:1410 in wrap_fx_proxy_cls
```
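One workaround I’ve been wondering about (untested sketch, assuming torch.func.functional_call works the way I think): pass each layer’s parameters in as plain tensors, so that only check_tensor guards remain and a single compiled function is shared by all layers:

```python
from torch.func import functional_call

template = Cell(64)  # a single module instance, fixed across all calls

@torch.compile
def step(params, x, hidden, T):
    # parameters arrive as ordinary tensor inputs, so the guards should be
    # check_tensor (shape/dtype/stride) rather than ___check_obj_id per layer
    return functional_call(template, params, (x, hidden, T))

def forward(x, hidden, T):
    for layer in layers:
        x, hidden = step(dict(layer.named_parameters()), x, hidden, T)
    return x, hidden
```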
Could you advise on the best tactic for successfully compiling long unrolled loops?
Thanks