I have a function with a for loop that compiles just fine on its own, but takes hours to compile when it is included in the much larger forward pass of an entire transformer. The function is basically a finite state machine (FSM) that computes which “segment type” each token belongs to, so it iterates over the sequence dimension of the input tokens.
```python
def compute_segments( self, tokens: torch.Tensor ):
    # Tensors of current state and id for each sequence in the batch
    current_state = torch.zeros( [ tokens.shape[0] ], dtype=torch.int, device=tokens.device )
    current_seg_id = torch.zeros( [ tokens.shape[0] ], dtype=torch.int, device=tokens.device )

    # Tensors of all states and ids for each element of each sequence in the batch
    states = torch.zeros_like( tokens, dtype=torch.int )
    seg_ids = torch.zeros_like( tokens, dtype=torch.int )

    # Get idxs of start, end and newline
    im_start_arr = tokens == self.sep_token_id
    im_end_arr = tokens == self.cls_token_id
    newline_arr = tokens == self.new_token_id

    # Loop over all tokens
    for i in range( tokens.shape[-1] ):
        # If token is <|im_start|>
        im_start = im_start_arr[ :, i ]
        # If token is <|im_end|>
        im_end = im_end_arr[ :, i ]
        # If token is \n # TODO: check for multiple types of newline perhaps?
        newline = newline_arr[ :, i ]

        # 4->0 if \n
        current_state.masked_fill_( ( current_state == 4 ).logical_and( newline ), 0 )
        # 3->4 if im_end
        current_state.masked_fill_( ( current_state == 3 ).logical_and( im_end ), 4 )
        # 2->3 if anything
        current_state.masked_fill_( ( current_state == 2 ), 3 )
        # 1->2 if \n
        current_state.masked_fill_( ( current_state == 1 ).logical_and( newline ), 2 )
        # 0->1 if im_start
        current_state.masked_fill_( ( current_state == 0 ).logical_and( im_start ), 1 )

        # If im_start is seen increment seg_id
        current_seg_id += ( im_start ).int()

        states[ :, i ] = current_state
        seg_ids[ :, i ] = current_seg_id

    segment_mask = torch.isin( states, torch.tensor( [ 1, 2, 3, 4 ] if self.include_prefix else [ 3, 4 ], device=states.device ) )
    class_mask = im_end_arr

    return states, seg_ids, segment_mask, class_mask
```
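For reference, here is a minimal smoke test of the function above in isolation. The special-token ids and the `SimpleNamespace` stand-in for `self` are made up purely for illustration; in the real model they come from the tokenizer and the module's attributes:

```python
import torch
from types import SimpleNamespace

# Stand-in for `self` with placeholder special-token ids (not the real tokenizer ids)
cfg = SimpleNamespace( sep_token_id=1, cls_token_id=2, new_token_id=3, include_prefix=True )

# One sequence: <|im_start|> a \n b c <|im_end|> d \n e
tokens = torch.tensor( [ [ 1, 5, 3, 7, 8, 2, 9, 3, 6 ] ] )

states, seg_ids, segment_mask, class_mask = compute_segments( cfg, tokens )
# states:       [[1, 1, 2, 3, 3, 4, 4, 0, 0]]
# seg_ids:      [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
# segment_mask: True for states 1-4 (include_prefix=True), i.e. positions 0..6
# class_mask:   True only at the <|im_end|> position (index 5)
```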
I tested this function in isolation with `torch.compile` and it takes seconds to compile. However, when it is included in the forward pass of an entire transformer model, with the whole forward pass compiled, the compile time inflates massively. I used the debug env vars to see what Dynamo is doing, and it looks like the entire for loop gets unrolled: the terminal fills with thousands of instances of Dynamo tracing `masked_fill_`, and once it hits the Inductor compile step it hangs. When I annotate the function with `@torch._dynamo.disable` this all goes away and the compile phase takes only minutes, but the uncompiled loop is significantly slower at runtime.
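For clarity, the workaround is nothing more than excluding this one method from tracing, so the rest of the forward pass stays compiled while the loop runs eagerly:

```python
# Keep the rest of the forward pass compiled, but run this method eagerly,
# so Dynamo never traces (and therefore never unrolls) the loop.
@torch._dynamo.disable
def compute_segments( self, tokens: torch.Tensor ):
    ...
```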
I worked around this by only calling this FSM function when `self.include_prefix` is True, because when it is False I can compute the segments with a different function that has no for loop, but that's kind of annoying.
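That loop-free path isn't shown here, but as a rough illustration of the idea, the running segment id from the loop above reduces to a cumulative count of `<|im_start|>` tokens (the token id below is a placeholder):

```python
import torch

# Illustration only: the seg_ids bookkeeping from the loop collapses to a cumsum
sep_token_id = 1                                          # placeholder id
tokens = torch.tensor( [ [ 1, 5, 3, 7, 2, 9, 1, 4 ] ] )   # two <|im_start|> markers
seg_ids = torch.cumsum( ( tokens == sep_token_id ).int(), dim=-1 )
# seg_ids -> [[1, 1, 1, 1, 1, 1, 2, 2]]
```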
Does torch have some sort of loop primitive that either never unrolls, or unrolls with a user-specified stride?