Can torch.compile support a model with a dict holding data from hooks?

Hi, I’m testing the new PyTorch 2.0 feature torch.compile() for my use case. However, when I use it naively, I encounter the following error:

Traceback (most recent call last):
  File "/home/sist/luoxin/projects/BasicSR/basicsr/train.py", line 272, in <module>
    train_pipeline(root_path)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/train.py", line 189, in train_pipeline
    model.optimize_parameters(current_iter)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/models/perceptual_score_model.py", line 97, in optimize_parameters
    d0 = self.net(self.p0, self.ref)
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 219, in forward
    x_features = self.net(x)
  File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 104, in forward
    features = self.forward_features(x)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 79, in forward_features
    feature = self.extractor.get_block_for_layer(layer)
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in get_block_for_layer
    block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in <graph break in get_block_for_layer>
    block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
  File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in <graph break in get_block_for_layer>
    block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
IndexError: list index out of range

Actually, I use a feature extractor for a Vision Transformer; the code is below:

class VitExtractor(torch.nn.Module):
    """Extract intermediate features from a (timm-style) Vision Transformer via forward hooks.

    ``return_nodes`` maps each ``*_KEY`` constant below to the list of block
    indices whose outputs should be captured. Hook outputs are appended to
    ``self.outputs_dict`` (same keys) during ``forward`` and read back through
    the ``get_*`` accessors.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    HEAD_KEY = 'head'
    PRE_LOGITS_KEY = 'pre_logits'
    CLS_TOKEN_KEY = 'cls_token'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY, HEAD_KEY, PRE_LOGITS_KEY, CLS_TOKEN_KEY]

    def __init__(self, model, return_nodes=None, range_norm=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
        """Wrap ``model``, register the requested hooks, and set up input normalization.

        Args:
            model: a ViT backbone exposing ``blocks``, ``patch_embed`` and ``head``.
            return_nodes (dict | None): key -> list of block indices to capture.
            range_norm (bool): if True, map inputs from [-1, 1] to [0, 1] first.
            mean / std: per-channel stats for inputs in [0, 1]; ``mean=None``
                disables input normalization.
        """
        super().__init__()

        # BUG FIX: the original used a mutable default argument (``return_nodes={}``)
        # and then mutated it in place — state leaked across instances and the
        # caller's dict was modified. Use None as the default and work on a copy.
        return_nodes = {} if return_nodes is None else dict(return_nodes)

        if VitExtractor.HEAD_KEY in return_nodes:
            return_nodes[VitExtractor.HEAD_KEY] = [len(model.blocks)]

        # BUG FIX: ``default=0`` guards against an empty ``return_nodes`` (or a
        # key mapped to an empty list), where the original raised ValueError on
        # ``max()`` of an empty sequence.
        max_layer = max(
            (max(indices) for indices in return_nodes.values() if indices),
            default=0,
        )

        # Drop trailing blocks that no requested hook can ever reach.
        self.model = remove_redundant_blocks(model, max_layer, model_type='vit')

        self.return_nodes = return_nodes
        self.use_input_norm = False
        self.range_norm = range_norm
        self.hook_handlers = []
        self.outputs_dict = {}
        # Ensure every known key exists in both dicts so accessors never KeyError.
        for key in VitExtractor.KEY_LIST:
            self.outputs_dict[key] = []
            if key not in self.return_nodes:
                self.return_nodes[key] = []

        self._register_hooks()

        if mean is not None:
            self.use_input_norm = True
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.tensor(mean).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.tensor(std).view(1, 3, 1, 1))

    def forward(self, x):
        """Run the backbone; requested features land in ``self.outputs_dict``."""
        # BUG FIX: clear features captured by any previous forward pass —
        # otherwise ``outputs_dict`` grows across calls and the index-based
        # getters (``get_prediction``, ``get_block_for_layer``, ...) read stale
        # entries from earlier inputs.
        self._init_hooks_data()
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.model.forward(x)

    def _register_hooks(self, **kwargs):
        """Attach forward hooks for every requested key/block combination."""
        img_size = to_2tuple(self.model.patch_embed.img_size)
        patch_size = to_2tuple(self.model.patch_embed.patch_size)

        # Spatial token-grid resolution (tokens per side, CLS excluded).
        resolution = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])

        for block_idx, block in enumerate(self.model.blocks):

            if block_idx in self.return_nodes[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook(resolution)))
            if block_idx in self.return_nodes[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.return_nodes[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook(block.attn.num_heads, resolution)))
            if block_idx in self.return_nodes[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
            if block_idx in self.return_nodes[VitExtractor.CLS_TOKEN_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_cls_token_hook()))

        # BUG FIX: __init__ fills every key with at least an empty list, so the
        # original ``key in self.return_nodes`` test was always True and these
        # hooks were registered unconditionally. Test for a non-empty request
        # list instead.
        if self.return_nodes[VitExtractor.HEAD_KEY]:
            self.hook_handlers.append(self.model.head.register_forward_hook(self._get_head_hook()))

        if self.return_nodes[VitExtractor.PRE_LOGITS_KEY]:
            self.hook_handlers.append(self.model.fc_norm.register_forward_hook(self._get_fc_norm_hook()))

    def _clear_hooks(self):
        """Remove all registered hooks (e.g. before discarding the extractor)."""
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _init_hooks_data(self):
        """Reset the captured-feature lists for a fresh forward pass."""
        for key in VitExtractor.KEY_LIST:
            self.outputs_dict[key] = []

    def _get_head_hook(self):
        def _get_head_output(model, input, output):
            # Final classification-head output.
            self.outputs_dict[VitExtractor.HEAD_KEY].append(output)

        return _get_head_output

    def _get_fc_norm_hook(self):
        def _get_fc_norm_output(model, input, output):
            # Pre-logits (post fc_norm) representation.
            self.outputs_dict[VitExtractor.PRE_LOGITS_KEY].append(output)

        return _get_fc_norm_output

    def _get_block_hook(self, input_resolution):
        def _get_block_output(model, input, output):
            # Drop the CLS token and fold the patch tokens back to a 2-D grid:
            # [B, 1+N, C] -> [B, H, W, C] with (H, W) == input_resolution.
            B, N, C = input[0].shape
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output[:, 1:, :].reshape(B, *input_resolution, C))

        return _get_block_output

    def _get_cls_token_hook(self, input_resolution=None):
        # BUG FIX: the original signature required ``input_resolution`` but the
        # call site in ``_register_hooks`` passes no argument, raising TypeError.
        # The parameter is unused, so it is now optional.
        def _get_cls_token_output(model, input, output):
            # Keep only the CLS token: [B, 1+N, C] -> [B, C].
            B, N, C = input[0].shape
            self.outputs_dict[VitExtractor.CLS_TOKEN_KEY].append(output[:, 0, :])

        return _get_cls_token_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            # Post-dropout attention probabilities.
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _get_attn_output

    def _get_qkv_hook(self, num_heads, input_resolution):
        def _get_qkv_output(model, inp, output):
            # Split the fused qkv projection, drop the CLS token, and reshape to
            # [3, B, heads, H, W, head_dim].
            B, N, C = inp[0].shape
            self.outputs_dict[VitExtractor.QKV_KEY].append(output.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)[:, :, :, 1:, :].reshape(3, B, num_heads, *input_resolution, C // num_heads))

        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _get_attn_output

    def get_layer_feature(self):  # List([B, N, D])
        # BUG FIX: ``VitExtractor.LAYER_KEY`` does not exist, so this method
        # always raised AttributeError. Block outputs are the per-layer features
        # this class captures, so presumably BLOCK_KEY was intended —
        # TODO(review): confirm against callers.
        return self.outputs_dict[VitExtractor.BLOCK_KEY]

    def get_block_feature(self):  # List([B, N, D])
        return self.outputs_dict[VitExtractor.BLOCK_KEY]

    def get_cls_token_feature(self):  # List([B, N, D])
        return self.outputs_dict[VitExtractor.CLS_TOKEN_KEY]

    def get_qkv_feature(self):
        return self.outputs_dict[VitExtractor.QKV_KEY]

    def get_attn_feature(self):
        return self.outputs_dict[VitExtractor.ATTN_KEY]

    def get_patch_size(self):
        # NOTE(review): ``_register_hooks`` reads ``self.model.patch_embed.patch_size``,
        # so ``self.model.patch_size`` may not exist on timm ViTs — confirm this
        # attribute is set on the wrapped model.
        return self.model.patch_size

    def get_width_patch_num(self, input_img_shape):
        """Number of patch columns for an input of shape (B, C, H, W)."""
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        """Number of patch rows for an input of shape (B, C, H, W)."""
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        """Total token count: patch grid plus the CLS token."""
        return 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))

    def get_queries_from_qkv(self, qkv):
        return qkv[0]

    def get_keys_from_qkv(self, qkv):
        return qkv[1]

    def get_values_from_qkv(self, qkv):
        return qkv[2]

    def get_keys(self, block_idx):
        """Attention keys captured at ``block_idx`` (must be in return_nodes[QKV_KEY])."""
        qkv_features = self.get_qkv_feature()[self.return_nodes[VitExtractor.QKV_KEY].index(block_idx)]
        return self.get_keys_from_qkv(qkv_features)

    def get_keys_self_sim(self, block_idx):
        """Key self-similarity map for ``block_idx``, shaped [N, H, W, H, W]."""
        keys = self.get_keys(block_idx)
        N, heads, h, w, d = keys.shape
        # Concatenate heads per spatial position before computing cosine similarity.
        concatenated_keys = keys.view(N, heads, h * w, d).transpose(1, 2).reshape(N, h * w, heads * d)
        ssim_map = attn_cosine_sim2(concatenated_keys)
        return ssim_map.reshape(N, h, w, h, w)

    def get_keys_for_all(self):
        return [self.get_keys(block_idx) for block_idx in self.return_nodes[VitExtractor.QKV_KEY]]

    def get_keys_self_sim_for_all(self):
        return [self.get_keys_self_sim(block_idx) for block_idx in self.return_nodes[VitExtractor.QKV_KEY]]

    def get_block_for_layer(self, block_idx):
        """Block output captured at ``block_idx`` (must be in return_nodes[BLOCK_KEY])."""
        return self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]

    def get_block_for_all(self):
        return [self.get_block_for_layer(block_idx) for block_idx in self.return_nodes[VitExtractor.BLOCK_KEY]]

    def get_prediction(self):
        return self.outputs_dict[VitExtractor.HEAD_KEY][0]

    def get_prelogits(self):
        return self.outputs_dict[VitExtractor.PRE_LOGITS_KEY][0]

    def get_cls_token_for_layer(self, block_idx):
        """CLS token captured at ``block_idx`` (must be in return_nodes[CLS_TOKEN_KEY])."""
        return self.get_cls_token_feature()[self.return_nodes[VitExtractor.CLS_TOKEN_KEY].index(block_idx)]

I guess the problem is rooted in this extractor. Can someone give me some ideas on how to redesign it to satisfy the requirements of torch.compile()?

Hard to say without a full repro — where are you using torch.compile()? If you can’t share a full repro, can you please try out the minifier, which should help produce a minimal repro: PyTorch 2.0 Troubleshooting — PyTorch master documentation