Hello, I am using flex attention in the following way. I have ‘sliding_window_mask’ defined inside ‘forward’. ‘forward’ fails to checkpoint because of the nested function definition; however, I have to define it this way because ‘window_size’ does not exist outside ‘forward’.
from torch.nn.attention.flex_attention import create_block_mask

def forward(self, window_size, …):
    def sliding_window_mask(b, h, q_idx, kv_idx):
        return abs(q_idx - kv_idx) <= window_size

    self.block_mask = create_block_mask(
        sliding_window_mask, B=None, H=None, Q_LEN=seq_len,
        KV_LEN=seq_len, _compile=True
    )
I have tried:
1: creating sliding_window_mask outside forward, adding a window_size parameter. This fails because create_block_mask requires sliding_window_mask to have exactly four parameters (see the first sketch after this list).
2: deleting sliding_window_mask from the checkpoint by overriding ‘__getstate__’, but this didn’t work (see the second sketch after this list).
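Roughly, attempt 1 looked like this (the standalone mask takes a fifth window_size argument, which is what create_block_mask rejects):

def sliding_window_mask(b, h, q_idx, kv_idx, window_size):
    return abs(q_idx - kv_idx) <= window_size

# Inside forward -- create_block_mask only calls the mask with
# (b, h, q_idx, kv_idx), so window_size never gets passed through:
self.block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=seq_len,
    KV_LEN=seq_len, _compile=True
)

And attempt 2 was along these lines, dropping the attribute that holds the block mask (and with it the reference to the nested function) from the pickled state:

def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("block_mask", None)  # don't pickle the mask built from the nested function
    return state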
any advice would be greatly appreciated, thanks!