Torch.compile can't cope with for loops inside large functions

I have a function containing a for loop that compiles just fine on its own, but takes hours to compile when it's part of the much larger forward pass of an entire transformer.

The function is basically a finite state machine (FSM) that computes which “segment type” each token belongs to, and as such it iterates over the sequence dimension of the input tokens.

    def compute_segments( self, tokens: torch.Tensor ):
        # Tensors of current state and id for each sequence in the batch
        current_state = torch.zeros( [ tokens.shape[0] ], dtype=torch.int, device=tokens.device )
        current_seg_id = torch.zeros( [ tokens.shape[0] ], dtype=torch.int, device=tokens.device )

        # Tensors of all states and ids for each element of each sequence in the batch
        states = torch.zeros_like( tokens, dtype=torch.int )
        seg_ids = torch.zeros_like( tokens, dtype=torch.int )

        # Get idxs of start, end and newline
        im_start_arr = tokens == self.sep_token_id
        im_end_arr = tokens == self.cls_token_id 
        newline_arr = tokens == self.new_token_id

        # Loop over all tokens
        for i in range( tokens.shape[-1] ):
            # If token is <|im_start|>
            im_start = im_start_arr[ :, i ]

            # If token is <|im_end|>
            im_end = im_end_arr[ :, i ]

            # If token is \n # TODO: check for multiple types of newline perhaps?
            newline = newline_arr[ :, i ]

            # 4->0 if \n
            current_state.masked_fill_( ( current_state == 4 ).logical_and( newline ), 0 )

            # 3->4 if im_end
            current_state.masked_fill_( ( current_state == 3 ).logical_and( im_end ), 4 )

            # 2->3 if anything
            current_state.masked_fill_( ( current_state == 2 ), 3 )

            # 1->2 if \n
            current_state.masked_fill_( ( current_state == 1 ).logical_and( newline ), 2 )

            # 0->1 if im_start
            current_state.masked_fill_( ( current_state == 0 ).logical_and( im_start ), 1 )

            # If im_start is seen increment seg_id
            current_seg_id += ( im_start ).int()

            states[ :, i ] = current_state
            seg_ids[ :, i ] = current_seg_id
        
        segment_mask = torch.isin( states, torch.tensor( [ 1, 2, 3, 4 ] if self.include_prefix else [ 3, 4 ], device=states.device ) )
        class_mask = im_end_arr
        
        return states, seg_ids, segment_mask, class_mask

I tested this function in isolation with torch.compile and it compiles in seconds. However, when it is included in the forward pass of an entire transformer model, with the whole forward pass compiled, it massively inflates the compile time. Using the debug environment variables to see what Dynamo is doing, it looks like the entire for loop gets unrolled: the terminal fills with thousands of instances of Dynamo tracing masked_fill_, and once it reaches the Inductor compile step it hangs. Annotating the function with @torch._dynamo.disable makes this all go away and the compile phase takes only minutes, but the uncompiled loop is significantly slower at runtime.
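
For reference, this is roughly how I'm applying the workaround. The module below is a stripped-down, hypothetical stand-in for the real model (only the seg_id part of the loop is reproduced), just to show where the decorator goes relative to the compiled forward pass:

    import torch

    class SegmentModel( torch.nn.Module ):
        # Hypothetical minimal module, not the real transformer
        def __init__( self, vocab_size: int = 32, dim: int = 16, sep_token_id: int = 1 ):
            super().__init__()
            self.embed = torch.nn.Embedding( vocab_size, dim )
            self.sep_token_id = sep_token_id

        @torch._dynamo.disable  # Dynamo skips this frame, so the loop is never unrolled
        def compute_segments( self, tokens: torch.Tensor ):
            current_seg_id = torch.zeros( [ tokens.shape[0] ], dtype=torch.int, device=tokens.device )
            seg_ids = torch.zeros_like( tokens, dtype=torch.int )
            for i in range( tokens.shape[-1] ):  # stays a plain Python loop, run eagerly
                current_seg_id += ( tokens[ :, i ] == self.sep_token_id ).int()
                seg_ids[ :, i ] = current_seg_id
            return seg_ids

        def forward( self, tokens: torch.Tensor ):
            seg_ids = self.compute_segments( tokens )  # eager call inside the otherwise compiled forward
            return self.embed( tokens ) + seg_ids.unsqueeze( -1 )

    model = torch.compile( SegmentModel() )
    out = model( torch.randint( 0, 32, ( 2, 12 ) ) )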

I worked around this by only calling this FSM function when self.include_prefix is True; when it is False I can compute the segments with a different function that has no for loop, but that's kinda annoying.
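
To make the loop-free direction concrete (this is only a sketch of the idea, not my actual replacement function, and it takes the token ids as plain arguments): the seg_id update inside the loop is just a running count of <|im_start|> tokens, so it reduces to a cumsum, and class_mask never needed the loop at all; only the FSM states genuinely carry information across positions.

    import torch

    def compute_segments_no_loop_sketch( tokens: torch.Tensor, sep_token_id: int, cls_token_id: int ):
        # Token masks, same as in the loop version
        im_start_arr = tokens == sep_token_id
        im_end_arr = tokens == cls_token_id

        # The loop's `current_seg_id += im_start.int()` accumulated over all
        # positions is just a cumulative sum along the sequence dimension
        seg_ids = torch.cumsum( im_start_arr.int(), dim=-1 )

        # class_mask never depended on the FSM state at all
        class_mask = im_end_arr

        # The states / segment_mask computation still carries state across
        # positions, which is the part the loop-free path has to handle differently
        return seg_ids, class_mask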

Does torch have some sort of loop primitive which either never unrolls, or unrolls with a user-specified stride?