My custom function, which runs before every conv layer and only contains one addition and division operation, slowed inference about 2x the original inference time.
I want to know how to accelerate it as fast as I can.
Thank you
example
from torchvision import models
import torch
def zscore_dr_hook(module: torch.nn.Module, data: tuple) -> None:
if torch.any((module.weight-0.2)/0.4 > 1000):
pass
res = models.resnet50(pretrained=True).to('cuda')
rand_input = torch.randn(8, 3, 224, 224, device='cuda:0')
for i in range(100):
res(rand_input)
keys = ['conv1.weight',
'layer1.0.conv1.weight',
'layer1.0.conv2.weight',
'layer1.0.conv3.weight',
'layer1.0.downsample.0.weight',
'layer1.1.conv1.weight',
'layer1.1.conv2.weight',
'layer1.1.conv3.weight',
'layer1.2.conv1.weight',
'layer1.2.conv2.weight',
'layer1.2.conv3.weight',
'layer2.0.conv1.weight',
'layer2.0.conv2.weight',
'layer2.0.conv3.weight',
'layer2.0.downsample.0.weight',
'layer2.1.conv1.weight',
'layer2.1.conv2.weight',
'layer2.1.conv3.weight',
'layer2.2.conv1.weight',
'layer2.2.conv2.weight',
'layer2.2.conv3.weight',
'layer2.3.conv1.weight',
'layer2.3.conv2.weight',
'layer2.3.conv3.weight',
'layer3.0.conv1.weight',
'layer3.0.conv2.weight',
'layer3.0.conv3.weight',
'layer3.0.downsample.0.weight',
'layer3.1.conv1.weight',
'layer3.1.conv2.weight',
'layer3.1.conv3.weight',
'layer3.2.conv1.weight',
'layer3.2.conv2.weight',
'layer3.2.conv3.weight',
'layer3.3.conv1.weight',
'layer3.3.conv2.weight',
'layer3.3.conv3.weight',
'layer3.4.conv1.weight',
'layer3.4.conv2.weight',
'layer3.4.conv3.weight',
'layer3.5.conv1.weight',
'layer3.5.conv2.weight',
'layer3.5.conv3.weight',
'layer4.0.conv1.weight',
'layer4.0.conv2.weight',
'layer4.0.conv3.weight',
'layer4.0.downsample.0.weight',
'layer4.1.conv1.weight',
'layer4.1.conv2.weight',
'layer4.1.conv3.weight',
'layer4.2.conv1.weight',
'layer4.2.conv2.weight',
'layer4.2.conv3.weight',
'fc.weight']
handles = []
def register_hook(
model, hook
) -> None:
for key in keys:
key = key.rsplit(".", 1)[0]
module = model.get_submodule(key)
handles.append(module.register_forward_pre_hook(hook=hook))
register_hook(res, zscore_dr_hook)
def remove_hook(model) -> None:
[i.remove() for i in handles]