Trace all matrix-vector multiplications in an nn.Module

Say I have a noisy matrix-vector multiplication (MVM) operation enabled by some novel hardware. I want to simulate how this noisy MVM will impact inference of an nn.Module. How can I trace and then replace all MVMs in a given nn.Module?

nn.Linear is perhaps the easiest to deal with, but attention and conv layers are also based on MVMs. Is there a general way that can handle any layer type?

Additionally, I suspect an nn.Module might sometimes call torch.matmul or nn.functional.linear directly instead of defining an nn.Linear submodule. How can I trace these operations as well?
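For concreteness, here is a rough sketch of what I mean, using torch.fx on a hypothetical Toy module (my own toy example, not from my real model) that mixes an nn.Linear with a raw torch.matmul. Symbolic tracing exposes the functional call as a call_function node and the submodule as a call_module node, so both kinds of MVM show up in the graph:

```python
import torch
import torch.nn as nn
from torch import fx

# Hypothetical toy module mixing an nn.Linear with a raw torch.matmul.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.w = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return torch.matmul(self.lin(x), self.w)

gm = fx.symbolic_trace(Toy())
# call_function nodes carry functional ops (torch.matmul, F.linear, ...);
# call_module nodes carry submodules such as nn.Linear.
for node in gm.graph.nodes:
    print(node.op, node.target)
```

So in principle the graph gives me both the direct torch.matmul call and the nn.Linear submodule, but I am unsure this generalizes to arbitrary layers.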

import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomNN(nn.Module):
    def __init__(self, input_dim, linear_dim, attention_dim, conv_in_channels, conv_out_channels, conv_kernel_size):
        super().__init__()
        self.linear = nn.Linear(input_dim, linear_dim)
        self.attention = Attention(linear_dim, attention_dim)  # some custom attention module (definition omitted)
        self.conv = nn.Conv2d(conv_in_channels, conv_out_channels, conv_kernel_size)
        self.fc = nn.Linear(conv_out_channels, 1)
        
    def forward(self, x):
        x = self.linear(x)
        x, attention_weights = self.attention(x)
        x = x.view(x.size(0), -1, 1, 1)
        x = F.relu(self.conv(x))
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x, attention_weights

model = CustomNN(input_dim, linear_dim, attention_dim, conv_in_channels, conv_out_channels, conv_kernel_size)

traced_model = trace(model)

# Runs the model with noisy MVMs. For simplicity, say the MVM error is just
# MVM(A, B) = A @ B + Gauss(0, 1). How should trace() be implemented?
# I want it to track the MVMs of self.linear, self.attention, and self.fc
# in CustomNN. More generally, trace() should ideally handle any layer type
# that uses MVMs internally.
traced_model(x) 
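One direction I have been considering (a rough sketch, not a working trace() yet) is torch.overrides.TorchFunctionMode, which intercepts every torch API call, including the F.linear calls made internally by nn.Linear, so no per-layer handling should be needed. The MVM_OPS set below is my own guess at which ops to treat as MVM-based, and the small Sequential model is just a stand-in for CustomNN:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode

# My guess at the ops to treat as MVM-based; would need extending
# (e.g. with F.conv2d, F.scaled_dot_product_attention) for a real model.
MVM_OPS = {torch.matmul, torch.mm, torch.bmm, F.linear}

class NoisyMVM(TorchFunctionMode):
    """Intercept every torch API call; add Gaussian noise after MVM ops."""
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)  # the mode is disabled inside the handler, so no recursion
        if func in MVM_OPS and isinstance(out, torch.Tensor):
            out = out + torch.randn_like(out)  # MVM(A, B) = A @ B + Gauss(0, 1)
        return out

# Stand-in model; nn.Linear calls F.linear internally, so it is intercepted too.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(2, 8)

with NoisyMVM():
    noisy_out = model(x)
clean_out = model(x)
```

I am not sure whether this catches every MVM-like op in attention layers, or whether an fx-based graph rewrite would be more appropriate, so any guidance is appreciated.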

Thanks

Hi @GoldenalCheese,

Could you give a minimal reproducible example in PyTorch and show the desired output/behavior?