Hi, I’m testing the new PyTorch 2.0 feature torch.compile()
for my use case. However, when I use it naively, I encounter the following errors:
Traceback (most recent call last):
File "/home/sist/luoxin/projects/BasicSR/basicsr/train.py", line 272, in <module>
train_pipeline(root_path)
File "/home/sist/luoxin/projects/BasicSR/basicsr/train.py", line 189, in train_pipeline
model.optimize_parameters(current_iter)
File "/home/sist/luoxin/projects/BasicSR/basicsr/models/perceptual_score_model.py", line 97, in optimize_parameters
d0 = self.net(self.p0, self.ref)
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 219, in forward
x_features = self.net(x)
File "/home/sist/luoxin/.conda/envs/py3.10+pytorch2.0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 104, in forward
features = self.forward_features(x)
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/perceptual_similarity_arch.py", line 79, in forward_features
feature = self.extractor.get_block_for_layer(layer)
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in get_block_for_layer
block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in <graph break in get_block_for_layer>
block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
File "/home/sist/luoxin/projects/BasicSR/basicsr/archs/extractor.py", line 208, in <graph break in get_block_for_layer>
block_features = self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]
IndexError: list index out of range
Actually, I use a feature extractor for Vision Transformers; the code is below:
class VitExtractor(torch.nn.Module):
    """Hook-based feature extractor for timm-style Vision Transformers.

    ``return_nodes`` maps keys from ``KEY_LIST`` to lists of block indices;
    forward hooks are registered on the corresponding submodules, and each
    forward pass appends the captured tensors to ``outputs_dict``, which the
    ``get_*`` accessors read back.

    NOTE(review): ``forward`` never clears ``outputs_dict``, so captured
    features accumulate across calls unless the caller invokes
    ``_init_hooks_data()`` between passes — confirm against the calling code.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    HEAD_KEY = 'head'
    PRE_LOGITS_KEY = 'pre_logits'
    CLS_TOKEN_KEY = 'cls_token'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY, HEAD_KEY, PRE_LOGITS_KEY, CLS_TOKEN_KEY]

    def __init__(self, model, return_nodes=None, range_norm=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
        """
        Args:
            model: a ViT exposing ``blocks``, ``patch_embed``, ``head`` and
                ``fc_norm`` submodules (timm layout — confirm).
            return_nodes (dict | None): key from ``KEY_LIST`` -> list of block
                indices to capture; ``None`` captures nothing. (Fix: was a
                mutable ``{}`` default that the constructor mutated in place,
                leaking entries across instances and into the caller's dict.)
            range_norm (bool): if True, inputs in [-1, 1] are shifted to [0, 1].
            mean / std: normalization stats for images in [0, 1]; ``mean=None``
                disables input normalization.
        """
        super().__init__()
        # Work on a copy so neither the caller's dict nor a shared default
        # argument is ever mutated.
        return_nodes = {} if return_nodes is None else dict(return_nodes)
        if VitExtractor.HEAD_KEY in return_nodes:
            # The classification head fires after the last block.
            return_nodes[VitExtractor.HEAD_KEY] = [len(model.blocks)]
        # Deepest block index that must survive truncation; ``default=0``
        # guards empty request lists (original raised ValueError on max([])).
        max_layer = max((max(idxs) for idxs in return_nodes.values() if idxs), default=0)
        self.model = remove_redundant_blocks(model, max_layer, model_type='vit')
        self.return_nodes = return_nodes
        self.use_input_norm = False
        self.range_norm = range_norm
        self.hook_handlers = []
        self.outputs_dict = {}
        # Ensure every known key exists (with an empty list when not
        # requested) so _register_hooks can index return_nodes unconditionally.
        for key in VitExtractor.KEY_LIST:
            self.outputs_dict[key] = []
            if key not in self.return_nodes:
                self.return_nodes[key] = []
        self._register_hooks()
        if mean is not None:
            self.use_input_norm = True
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor(mean).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor(std).view(1, 3, 1, 1))

    def forward(self, x):
        """Normalize the input and run the (truncated) model; requested
        features are captured as a side effect via the registered hooks."""
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.model.forward(x)

    def _register_hooks(self, **kwargs):
        """Attach a forward hook for every requested node."""
        img_size = to_2tuple(self.model.patch_embed.img_size)
        patch_size = to_2tuple(self.model.patch_embed.patch_size)
        # Spatial grid of patch tokens (height, width).
        resolution = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.return_nodes[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook(resolution)))
            if block_idx in self.return_nodes[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.return_nodes[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook(block.attn.num_heads, resolution)))
            if block_idx in self.return_nodes[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
            if block_idx in self.return_nodes[VitExtractor.CLS_TOKEN_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_cls_token_hook()))
        # Fix: __init__ inserts EVERY key of KEY_LIST into return_nodes (with
        # [] when not requested), so the original ``key in self.return_nodes``
        # tests were always true and the head/fc_norm hooks were registered
        # unconditionally. Register them only when actually requested.
        if self.return_nodes[VitExtractor.HEAD_KEY]:
            self.hook_handlers.append(self.model.head.register_forward_hook(self._get_head_hook()))
        if self.return_nodes[VitExtractor.PRE_LOGITS_KEY]:
            self.hook_handlers.append(self.model.fc_norm.register_forward_hook(self._get_fc_norm_hook()))

    def _clear_hooks(self):
        """Remove every registered hook handle."""
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _init_hooks_data(self):
        """Reset all captured outputs (call between forward passes)."""
        for key in VitExtractor.KEY_LIST:
            self.outputs_dict[key] = []

    def _get_head_hook(self):
        def _get_head_output(model, input, output):
            self.outputs_dict[VitExtractor.HEAD_KEY].append(output)
        return _get_head_output

    def _get_fc_norm_hook(self):
        def _get_fc_norm_output(model, input, output):
            self.outputs_dict[VitExtractor.PRE_LOGITS_KEY].append(output)
        return _get_fc_norm_output

    def _get_block_hook(self, input_resolution):
        """Capture a block's patch tokens as (B, H, W, C); the CLS token
        (position 0) is dropped before reshaping."""
        def _get_block_output(model, input, output):
            B, N, C = input[0].shape
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output[:, 1:, :].reshape(B, *input_resolution, C))
        return _get_block_output

    def _get_cls_token_hook(self):
        # Fix: the original signature required an ``input_resolution``
        # argument it never used, while _register_hooks calls this with no
        # arguments — so requesting CLS tokens raised TypeError at
        # registration time.
        def _get_cls_token_output(model, input, output):
            self.outputs_dict[VitExtractor.CLS_TOKEN_KEY].append(output[:, 0, :])
        return _get_cls_token_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
        return _get_attn_output

    def _get_qkv_hook(self, num_heads, input_resolution):
        """Capture qkv projections as (3, B, heads, H, W, C // heads); the
        CLS token is dropped."""
        def _get_qkv_output(model, inp, output):
            B, N, C = inp[0].shape
            self.outputs_dict[VitExtractor.QKV_KEY].append(
                output.reshape(B, N, 3, num_heads, C // num_heads)
                .permute(2, 0, 3, 1, 4)[:, :, :, 1:, :]
                .reshape(3, B, num_heads, *input_resolution, C // num_heads))
        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
        return _get_attn_output

    def get_layer_feature(self):  # List([B, H, W, D])
        # Fix: referenced ``VitExtractor.LAYER_KEY``, which was never defined,
        # so this method always raised AttributeError. The per-layer features
        # are presumably the block outputs — confirm against callers.
        return self.outputs_dict[VitExtractor.BLOCK_KEY]

    def get_block_feature(self):  # List([B, H, W, D])
        return self.outputs_dict[VitExtractor.BLOCK_KEY]

    def get_cls_token_feature(self):  # List([B, D])
        return self.outputs_dict[VitExtractor.CLS_TOKEN_KEY]

    def get_qkv_feature(self):
        return self.outputs_dict[VitExtractor.QKV_KEY]

    def get_attn_feature(self):
        return self.outputs_dict[VitExtractor.ATTN_KEY]

    def get_patch_size(self):
        # NOTE(review): timm ViTs usually expose the patch size at
        # ``model.patch_embed.patch_size`` — confirm ``model.patch_size``
        # actually exists on this model.
        return self.model.patch_size

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        # +1 accounts for the CLS token.
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_queries_from_qkv(self, qkv):
        return qkv[0]

    def get_keys_from_qkv(self, qkv):
        return qkv[1]

    def get_values_from_qkv(self, qkv):
        return qkv[2]

    def get_keys(self, block_idx):
        """Return the captured key tensor for ``block_idx`` (must be one of
        the indices requested under QKV_KEY)."""
        qkv_features = self.get_qkv_feature()[self.return_nodes[VitExtractor.QKV_KEY].index(block_idx)]
        return self.get_keys_from_qkv(qkv_features)

    def get_keys_self_sim(self, block_idx):
        """Cosine self-similarity of the keys at ``block_idx``, reshaped to a
        (N, h, w, h, w) similarity map."""
        keys = self.get_keys(block_idx)
        N, heads, h, w, d = keys.shape
        concatenated_keys = keys.view(N, heads, h * w, d).transpose(1, 2).reshape(N, h * w, heads * d)
        ssim_map = attn_cosine_sim2(concatenated_keys)
        return ssim_map.reshape(N, h, w, h, w)

    def get_keys_for_all(self):
        return [self.get_keys(block_idx) for block_idx in self.return_nodes[VitExtractor.QKV_KEY]]

    def get_keys_self_sim_for_all(self):
        return [self.get_keys_self_sim(block_idx) for block_idx in self.return_nodes[VitExtractor.QKV_KEY]]

    def get_block_for_layer(self, block_idx):
        """Return the captured block output for ``block_idx`` (must be one of
        the indices requested under BLOCK_KEY)."""
        return self.get_block_feature()[self.return_nodes[VitExtractor.BLOCK_KEY].index(block_idx)]

    def get_block_for_all(self):
        return [self.get_block_for_layer(block_idx) for block_idx in self.return_nodes[VitExtractor.BLOCK_KEY]]

    def get_prediction(self):
        return self.outputs_dict[VitExtractor.HEAD_KEY][0]

    def get_prelogits(self):
        return self.outputs_dict[VitExtractor.PRE_LOGITS_KEY][0]

    def get_cls_token_for_layer(self, block_idx):
        return self.get_cls_token_feature()[self.return_nodes[VitExtractor.CLS_TOKEN_KEY].index(block_idx)]
I guess the problem is rooted in this extractor. Can someone give me some ideas on how to redesign it to satisfy the requirements of torch.compile()?