I saved the quantized weights and loaded them into the model after torch.ao.quantization.convert(). How do I print the output of each layer of the network?

After going through torch.ao.quantization.convert(), adding print statements in the original network code seems to have no effect.

torch.ao.quantization.convert is eager-mode quantization, meaning it fully swaps each module for a quantized version. If you put a print into a module that gets swapped, it won't do anything once that module is gone.
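For illustration, here is a minimal sketch of that swap; the toy Sequential, the layer sizes, and the fbgemm qconfig are just assumptions for the example:

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

# toy float model; QuantStub/DeQuantStub mark where tensors get (de)quantized
float_model = nn.Sequential(QuantStub(), nn.Conv2d(1, 1, 1), DeQuantStub()).eval()
float_model.qconfig = get_default_qconfig('fbgemm')

prepared = prepare(float_model)       # inserts observers (returns a copy by default)
prepared(torch.randn(1, 1, 4, 4))     # one calibration pass to populate the observers
model_int8 = convert(prepared)        # swaps each float module for its quantized counterpart

print(type(float_model[1]))   # <class 'torch.nn.modules.conv.Conv2d'>, the original float conv
print(type(model_int8[1]))    # quantized Conv2d; a print added inside the float conv's forward never runs here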

We usually do something like:

# intermediate layer analysis with forward hooks
import torch
import torch.nn as nn

activation = {}

def get_activation(name):
    # return a hook that records the module's input and output under `name`
    def hook(model, input, output):
        activation[name] = {'input': input, 'output': output}
    return hook

def add_hooks(m):
    # register a forward hook on every named submodule (and the root module)
    for k, v in m.named_modules():
        print(k, v)
        v.register_forward_hook(get_activation(k))

m = nn.Sequential(nn.Sequential(nn.Conv2d(1, 1, 1)), nn.Conv2d(1, 1, 1))
add_hooks(m)
m(torch.randn(1, 1, 1, 1))
print(activation)

Absolutely, forward_hook is the simplest solution. I'm curious whether it applies to the quantized model as well. I tried the above code snippet, but after running the quantization flow and calling

add_hooks(model_int8)

the activation dict doesn't capture anything. Is there any officially recommended way to extract intermediate activations from a quantized model?

When you run the above snippet, it doesn't hook anything? Because I use this regularly to capture intermediate activations for quantized models.
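For reference, a self-contained sketch of that flow; the toy module, layer sizes, and fbgemm qconfig are just placeholders, and the hook helpers repeat the snippet above:

import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

# same hook helpers as the snippet above
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = {'input': input, 'output': output}
    return hook

def add_hooks(m):
    for k, v in m.named_modules():
        v.register_forward_hook(get_activation(k))

# toy eager-mode model with explicit quant/dequant stubs
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(1, 1, 1)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = M().eval()
m.qconfig = get_default_qconfig('fbgemm')
prepared = prepare(m)
prepared(torch.randn(1, 1, 4, 4))      # calibration pass
model_int8 = convert(prepared)

add_hooks(model_int8)                  # hooks attach to the swapped-in quantized modules
model_int8(torch.randn(1, 1, 4, 4))
print(activation.keys())               # '', 'quant', 'conv', 'dequant' are all captured

Note that the output captured for 'conv' is a quantized tensor; call .dequantize() on it if you want to compare against float activations.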

You are right, the code works.
In my last reply it was not working because my q_model was scripted, I think.

Yeah, hooks are module-oriented, so scripting, which removes the modules, would break them.